# 

In [3]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import GridspecLayout
from IPython.display import display, clear_output
import io
import numpy as np # For best-fit line

# --- 2. Database Helper Functions ---
DB_NAME = 'Dataset.db'

def get_table_names():
    """Fetches all table names."""
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0].strip() for table in cursor.fetchall() if table[0] != 'sqlite_sequence']
        conn.close()
        return tables
    except Exception as e:
        print(f"Database error: {e}")
        return []

def get_column_names(table_name):
    """Fetches all column names for a table."""
    if not table_name:
        return []
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute(f'PRAGMA table_info("{table_name}")')
        columns = [col[1].strip() for col in cursor.fetchall()]
        conn.close()
        return columns
    except Exception as e:
        print(f"Database error: {e}")
        return []

# --- 3. Create GUI Widgets ---
table_dropdown = widgets.Dropdown(description='Select Table:')
x_axis_dropdown = widgets.Dropdown(description='X-Axis:')
y_axis_dropdown = widgets.Dropdown(description='Y-Axis:')
plot_type_dropdown = widgets.Dropdown(options=['Scatter', 'Bar'], description='Plot Type:')
best_fit_checkbox = widgets.Checkbox(value=True, description='Add Best Fit Curve')
generate_button = widgets.Button(description='Generate Plot')
plot_output = widgets.Output()

# --- 4. Define Event Handlers ---
def on_table_change(change):
    """Called when the table dropdown changes."""
    table = change['new']
    columns = get_column_names(table)
    x_axis_dropdown.options = columns
    y_axis_dropdown.options = columns
    if columns:
        x_axis_dropdown.value = columns[0]
        y_axis_dropdown.value = columns[1] if len(columns) > 1 else columns[0]

def on_plot_type_change(change):
    """Called when the plot type dropdown changes."""
    if change['new'] == 'Scatter':
        best_fit_checkbox.layout.display = 'flex'
        y_axis_dropdown.description = 'Y-Axis:'
    else:
        best_fit_checkbox.layout.display = 'none'
        y_axis_dropdown.description = 'Y-Axis (for Count):'

def on_generate_click(b):
    """Called when the 'Generate Plot' button is clicked."""
    with plot_output:
        clear_output(wait=True)  # Clear the old plot
        
        try:
            # --- Get all user selections ---
            table = table_dropdown.value
            x_col = x_axis_dropdown.value
            y_col = y_axis_dropdown.value 
            plot_type = plot_type_dropdown.value
            show_best_fit = best_fit_checkbox.value

            # --- Fetch data (including 'age' for filtering) ---
            all_columns = get_column_names(table)
            cols_to_fetch = set([x_col, y_col])
            if 'age' in all_columns:
                cols_to_fetch.add('age')
            
            col_str = '", "'.join(cols_to_fetch)
            query = f'SELECT "{col_str}" FROM "{table}"'
            
            conn = sqlite3.connect(DB_NAME)
            df = pd.read_sql_query(query, conn)
            conn.close()

            # --- Apply Age Filter (18-22) ---
            if 'age' in df.columns:
                df['age'] = pd.to_numeric(df['age'], errors='coerce')
                df = df.dropna(subset=['age'])
                df['age'] = df['age'].astype(float)
                
                df_filtered = df[(df['age'] >= 18) & (df['age'] <= 22)]
                
                if not df_filtered.empty:
                    df = df_filtered
                    print(f"Filter applied: Showing {len(df)} students aged 18-22.")
                else:
                    print("Note: No data found for ages 18-22. Showing all ages.")

            # --- Clean data based on plot type ---
            df_cleaned = df.copy()
            if plot_type == 'Scatter':
                df_cleaned[x_col] = pd.to_numeric(df_cleaned[x_col], errors='coerce')
                df_cleaned[y_col] = pd.to_numeric(df_cleaned[y_col], errors='coerce')
                df_cleaned = df_cleaned.dropna(subset=[x_col, y_col])
            elif plot_type == 'Bar':
                df_cleaned = df_cleaned.dropna(subset=[x_col])
            
            if df_cleaned.empty:
                print("Error: No valid data to plot after cleaning/filtering.")
                return

            # --- Plotting ---
            fig, ax = plt.subplots()
            
            y_label = y_col if plot_type == 'Scatter' else 'Count'
            title = f"{y_col} vs. {x_col}" if plot_type == 'Scatter' else f"Count of Students vs. {x_col}"
            
            if plot_type == 'Scatter':
                ax.scatter(df_cleaned[x_col], df_cleaned[y_col], alpha=0.7)
                if show_best_fit and len(df_cleaned[x_col]) > 1:
                    m, b = np.polyfit(df_cleaned[x_col], df_cleaned[y_col], 1)
                    ax.plot(df_cleaned[x_col], m * df_cleaned[x_col] + b, color='red')
            
            elif plot_type == 'Bar':
                grouped_data = df_cleaned.groupby(df_cleaned[x_col].astype(str))[y_col].count()
                ax.bar(grouped_data.index, grouped_data.values)
                if len(grouped_data) > 20:
                    ax.set_xticks([])
                else:
                    plt.xticks(rotation=45, ha='right')

            # --- Set labels and save the plot ---
            ax.set_xlabel(x_col)
            ax.set_ylabel(y_label) 
            ax.set_title(f"{title} (Table: {table})")
            ax.grid(True)
            plt.tight_layout()
            
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            buf.seek(0)
            
            image_widget = widgets.Image(value=buf.read(), format='png')
            display(image_widget)
            
            buf.close()
            plt.close(fig)

        except Exception as e:
            print(f"An unexpected error occurred: {e}")

# --- 5. Wire Up Widgets ---
table_dropdown.observe(on_table_change, names='value')
plot_type_dropdown.observe(on_plot_type_change, names='value')
generate_button.on_click(on_generate_click)

# --- 6. Create Layout and Display GUI ---
grid = GridspecLayout(3, 2)
grid[0, 0] = table_dropdown
grid[0, 1] = plot_type_dropdown
grid[1, 0] = x_axis_dropdown
grid[1, 1] = y_axis_dropdown
grid[2, 0] = generate_button
grid[2, 1] = best_fit_checkbox

gui_layout = widgets.VBox([grid, plot_output])

# --- 7. Initial Population ---
table_names = get_table_names()
if table_names:
    table_dropdown.options = table_names
    table_dropdown.value = table_names[0]
else:
    with plot_output:
        print("Warning: 'Dataset.db' not found or is empty.")

on_plot_type_change({'new': plot_type_dropdown.value})

# --- 8. Display the GUI ---
display(gui_layout)

VBox(children=(GridspecLayout(children=(Dropdown(description='Select Table:', layout=Layout(grid_area='widget0…

In [2]:
import sqlite3
import csv
import sys

DB_NAME = 'Dataset.db'
CSV_NAME = 'StressLevelDataset.csv'

try:
    # --- 1. Read all data from the CSV ---
    data_rows = []
    with open(CSV_NAME, 'r', newline='') as csvfile:
        reader = csv.reader(csvfile)
        header = next(reader) # Skip the header row
        
        for row in reader:
            data_rows.append(row)
    
    print(f"Read {len(data_rows)} rows from {CSV_NAME}.")
    
    # Check if data matches the 843 student_ids in the Students table
    if len(data_rows) != 843:
        print(f"Warning: CSV has {len(data_rows)} rows, but 'Students' table has 843. This may cause issues.")

    # --- 2. Connect to the database ---
    conn = sqlite3.connect(DB_NAME)
    cursor = conn.cursor()

    # --- 3. Clear old data from the tables ---
    print("Clearing old data from tables...")
    cursor.execute("DELETE FROM Academic;")
    cursor.execute("DELETE FROM Psychological;")
    cursor.execute("DELETE FROM Physiological;")

    # --- 4. Define column mappings (CSV index -> DB column) ---
    # These indices are based on the 'StressLevelDataset.csv' headers
    psychological_cols = {
        'anxiety_level': 0,
        'self_esteem': 1,
        'mental_health_history': 2,
        'depression': 3
    }
    
    physiological_cols = {
        'headache': 4,
        'blood_pressure': 5,
        'sleep_quality': 6,
        'breathing_problem': 7
    }
    
    academic_cols = {
        'academic_performance': 12,
        'study_load': 13,
        'teacher_student_relationship': 14,
        'future_career_concerns': 15
    }

    # --- 5. Loop through each row and insert data ---
    student_id = 1
    for row in data_rows:
        try:
            # Insert into Psychological
            p_data = (
                student_id, 
                row[psychological_cols['anxiety_level']],
                row[psychological_cols['self_esteem']],
                row[psychological_cols['mental_health_history']],
                row[psychological_cols['depression']]
            )
            cursor.execute("INSERT INTO Psychological (student_id, anxiety_level, self_esteem, mental_health_history, depression) VALUES (?, ?, ?, ?, ?);", p_data)

            # Insert into Physiological
            phys_data = (
                student_id,
                row[physiological_cols['blood_pressure']],
                row[physiological_cols['breathing_problem']],
                row[physiological_cols['sleep_quality']],
                row[physiological_cols['headache']]
            )
            cursor.execute("INSERT INTO Physiological (student_id, blood_pressure, breathing_problem, sleep_quality, headache) VALUES (?, ?, ?, ?, ?);", phys_data)
            
            # Insert into Academic
            a_data = (
                student_id,
                row[academic_cols['academic_performance']],
                row[academic_cols['study_load']],
                row[academic_cols['teacher_student_relationship']],
                row[academic_cols['future_career_concerns']]
            )
            cursor.execute("INSERT INTO Academic (student_id, academic_performance, study_load, teacher_student_relationship, future_career_concerns) VALUES (?, ?, ?, ?, ?);", a_data)
            
            student_id += 1
            
        except IndexError:
            print(f"Skipped row {student_id}. Data was incomplete.")
        except Exception as e:
            print(f"Error on row {student_id}: {e}")
            
    # --- 6. Commit changes and close ---
    conn.commit()
    conn.close()
    
    print("---")
    print(f"Successfully populated all tables for {student_id - 1} students.")
    print("Your database is now correct. Please re-run the GUI code cell.")

except FileNotFoundError:
    print(f"ERROR: '{CSV_NAME}' not found. Make sure it is in the same folder.")
except Exception as e:
    print(f"An unexpected error occurred: {e}")

Read 1100 rows from StressLevelDataset.csv.
Clearing old data from tables...
---
Successfully populated all tables for 1100 students.
Your database is now correct. Please re-run the GUI code cell.


In [4]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import GridspecLayout
from IPython.display import display, clear_output
import io
import numpy as np # For best-fit line

# --- 2. Database Helper Functions ---
DB_NAME = 'Dataset.db'

def get_table_names():
    """Fetches all table names."""
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0].strip() for table in cursor.fetchall() if table[0] != 'sqlite_sequence']
        conn.close()
        return tables
    except Exception as e:
        print(f"Database error: {e}")
        return []

def get_column_names(table_name):
    """Fetches all column names for a table."""
    if not table_name:
        return []
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute(f'PRAGMA table_info("{table_name}")')
        columns = [col[1].strip() for col in cursor.fetchall()]
        conn.close()
        return columns
    except Exception as e:
        print(f"Database error: {e}")
        return []

# --- 3. Create GUI Widgets ---
table_dropdown = widgets.Dropdown(description='Select Table:')
x_axis_dropdown = widgets.Dropdown(description='X-Axis:')
y_axis_dropdown = widgets.Dropdown(description='Y-Axis:')
plot_type_dropdown = widgets.Dropdown(options=['Scatter', 'Bar'], description='Plot Type:')
best_fit_checkbox = widgets.Checkbox(value=True, description='Add Best Fit Curve')
generate_button = widgets.Button(description='Generate Plot')
plot_output = widgets.Output()

# --- 4. Define Event Handlers ---
def on_table_change(change):
    """Called when the table dropdown changes."""
    table = change['new']
    columns = get_column_names(table)
    x_axis_dropdown.options = columns
    y_axis_dropdown.options = columns
    if columns:
        x_axis_dropdown.value = columns[0]
        y_axis_dropdown.value = columns[1] if len(columns) > 1 else columns[0]

def on_plot_type_change(change):
    """Called when the plot type dropdown changes."""
    if change['new'] == 'Scatter':
        best_fit_checkbox.layout.display = 'flex'
        y_axis_dropdown.description = 'Y-Axis:'
    else:
        best_fit_checkbox.layout.display = 'none'
        y_axis_dropdown.description = 'Y-Axis (for Count):'

def on_generate_click(b):
    """Called when the 'Generate Plot' button is clicked."""
    with plot_output:
        clear_output(wait=True)  # Clear the old plot
        
        try:
            # --- Get all user selections ---
            table = table_dropdown.value
            x_col = x_axis_dropdown.value
            y_col = y_axis_dropdown.value 
            plot_type = plot_type_dropdown.value
            show_best_fit = best_fit_checkbox.value

            # --- Fetch data (including 'age' for filtering) ---
            all_columns = get_column_names(table)
            cols_to_fetch = set([x_col, y_col])
            if 'age' in all_columns:
                cols_to_fetch.add('age')
            
            col_str = '", "'.join(cols_to_fetch)
            query = f'SELECT "{col_str}" FROM "{table}"'
            
            conn = sqlite3.connect(DB_NAME)
            df = pd.read_sql_query(query, conn)
            conn.close()

            # --- Apply Strict Age Filter (18-22) ---
            if 'age' in df.columns:
                df['age'] = pd.to_numeric(df['age'], errors='coerce')
                df = df.dropna(subset=['age'])
                df['age'] = df['age'].astype(float)
                
                df_filtered = df[(df['age'] >= 18) & (df['age'] <= 22)]
                
                if not df_filtered.empty:
                    df = df_filtered
                    print(f"Filter applied: Showing {len(df)} students aged 18-22.")
                else:
                    # Set df to the empty filtered list
                    df = df_filtered 
                    print("Filter applied: No students found between ages 18-22.")

            # --- Check if data is empty AFTER filtering ---
            if df.empty:
                print("Error: No data to plot after applying filters.")
                return

            # --- Clean data based on plot type ---
            df_cleaned = df.copy()
            if plot_type == 'Scatter':
                df_cleaned[x_col] = pd.to_numeric(df_cleaned[x_col], errors='coerce')
                df_cleaned[y_col] = pd.to_numeric(df_cleaned[y_col], errors='coerce')
                df_cleaned = df_cleaned.dropna(subset=[x_col, y_col])
            elif plot_type == 'Bar':
                df_cleaned = df_cleaned.dropna(subset=[x_col])
            
            if df_cleaned.empty:
                print("Error: No valid data to plot after cleaning.")
                return

            # --- Plotting ---
            fig, ax = plt.subplots()
            
            y_label = y_col if plot_type == 'Scatter' else 'Count'
            title = f"{y_col} vs. {x_col}" if plot_type == 'Scatter' else f"Count of Students vs. {x_col}"
            
            if plot_type == 'Scatter':
                ax.scatter(df_cleaned[x_col], df_cleaned[y_col], alpha=0.7)
                if show_best_fit and len(df_cleaned[x_col]) > 1:
                    m, b = np.polyfit(df_cleaned[x_col], df_cleaned[y_col], 1)
                    ax.plot(df_cleaned[x_col], m * df_cleaned[x_col] + b, color='red')
            
            elif plot_type == 'Bar':
                grouped_data = df_cleaned.groupby(df_cleaned[x_col].astype(str))[y_col].count()
                ax.bar(grouped_data.index, grouped_data.values)
                if len(grouped_data) > 20:
                    ax.set_xticks([])
                else:
                    plt.xticks(rotation=45, ha='right')

            # --- Set labels and save the plot ---
            ax.set_xlabel(x_col)
            ax.set_ylabel(y_label) 
            ax.set_title(f"{title} (Table: {table})")
            ax.grid(True)
            plt.tight_layout()
            
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            buf.seek(0)
            
            image_widget = widgets.Image(value=buf.read(), format='png')
            display(image_widget)
            
            buf.close()
            plt.close(fig)

        except Exception as e:
            print(f"An unexpected error occurred: {e}")

# --- 5. Wire Up Widgets ---
table_dropdown.observe(on_table_change, names='value')
plot_type_dropdown.observe(on_plot_type_change, names='value')
generate_button.on_click(on_generate_click)

# --- 6. Create Layout and Display GUI ---
grid = GridspecLayout(3, 2)
grid[0, 0] = table_dropdown
grid[0, 1] = plot_type_dropdown
grid[1, 0] = x_axis_dropdown
grid[1, 1] = y_axis_dropdown
grid[2, 0] = generate_button
grid[2, 1] = best_fit_checkbox

gui_layout = widgets.VBox([grid, plot_output])

# --- 7. Initial Population ---
table_names = get_table_names()
if table_names:
    table_dropdown.options = table_names
    table_dropdown.value = table_names[0]
else:
    with plot_output:
        print("Warning: 'Dataset.db' not found or is empty.")

on_plot_type_change({'new': plot_type_dropdown.value})

# --- 8. Display the GUI ---
display(gui_layout)

VBox(children=(GridspecLayout(children=(Dropdown(description='Select Table:', layout=Layout(grid_area='widget0…

In [5]:
import sqlite3
conn = sqlite3.connect('Dataset.db')
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
print(cursor.fetchall())
conn.close()


[('sqlite_sequence',), ('Students',), ('Academic',), ('Psychological',), ('Physiological',)]


In [6]:
print(get_column_names('Students'))


['student_id', 'age', 'gender']


In [7]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# --- 2. Database Setup ---
# Make sure your 'Dataset.db' file is in the same folder as this notebook
DB_NAME = 'Dataset.db'

# --- 3. Get All Tables from Database ---
try:
    conn = sqlite3.connect(DB_NAME)
    tables = pd.read_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
        conn
    )
    table_list = tables['name'].tolist()
    conn.close()
except Exception as e:
    print("Error loading database:", e)
    table_list = []

# --- 4. Create Widgets ---
table_dropdown = widgets.Dropdown(options=table_list, description="Table:")
plot_type = widgets.Dropdown(options=["Scatter", "Bar"], description="Plot Type:")
x_dropdown = widgets.Dropdown(description="X-Axis:")
y_dropdown = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style='success')
output_area = widgets.Output()

# --- 5. When user picks a table ---
def update_columns(change):
    clear_output(wait=True)
    display(ui)

    table = table_dropdown.value
    ptype = plot_type.value

    # Load sample data
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 100', conn)
    conn.close()

    # Choose columns based on plot type
    if ptype == "Scatter":
        numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
        x_dropdown.options = numeric_cols
        y_dropdown.options = numeric_cols
    else:
        all_cols = df.columns.tolist()
        x_dropdown.options = all_cols
        y_dropdown.options = all_cols

# --- 6. When user clicks "Generate Plot" ---
def generate_plot(b):
    with output_area:
        clear_output(wait=True)

        table = table_dropdown.value
        x_col = x_dropdown.value
        y_col = y_dropdown.value
        ptype = plot_type.value

        if not table or not x_col or not y_col:
            print("⚠️ Please select a table and valid columns.")
            return

        # Load full table data
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', conn)
        conn.close()

        # Filter by age if available
        if 'age' in df.columns:
            df['age'] = pd.to_numeric(df['age'], errors='coerce')
            df = df[(df['age'] >= 18) & (df['age'] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        if df.empty:
            print("No data available after filtering.")
            return

        # --- Create Plot ---
        fig, ax = plt.subplots()

        if ptype == "Scatter":
            # Convert to numbers and drop missing values
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col])

            ax.scatter(df[x_col], df[y_col], color='blue', alpha=0.7)

            # Add best-fit line if checked
            if best_fit.value and len(df) > 1:
                m, b = np.polyfit(df[x_col], df[y_col], 1)
                ax.plot(df[x_col], m * df[x_col] + b, color='red')
            ax.set_ylabel(y_col)

        else:  # Bar plot
            grouped = df.groupby(x_col)[y_col].count()
            grouped.plot(kind='bar', ax=ax, color='skyblue')
            plt.xticks(rotation=45, ha='right')
            ax.set_ylabel("Count")

        ax.set_xlabel(x_col)
        ax.set_title(f"{y_col} vs. {x_col} ({ptype})")
        ax.grid(True)
        plt.tight_layout()

        # Show image inside widget
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        display(widgets.Image(value=buf.read(), format='png'))
        plt.close(fig)

# --- 7. Link Widgets to Events ---
table_dropdown.observe(update_columns, names='value')
plot_type.observe(update_columns, names='value')
generate_button.on_click(generate_plot)

# --- 8. Display Interface ---
ui = widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_dropdown, y_dropdown]),
    widgets.HBox([generate_button, best_fit]),
    output_area
])
display(ui)


VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [8]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# --- 2. Database Setup ---
DB_NAME = 'Dataset.db'  # Ensure your database file is in the same directory

# --- 3. Get All Tables from Database ---
try:
    conn = sqlite3.connect(DB_NAME)
    tables = pd.read_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
        conn
    )
    table_list = tables['name'].tolist()
    conn.close()
except Exception as e:
    print("Error loading database:", e)
    table_list = []

# --- 4. Create Widgets ---
table_dropdown = widgets.Dropdown(options=table_list, description="Table:")
plot_type = widgets.Dropdown(options=["Scatter", "Bar"], description="Plot Type:")
x_dropdown = widgets.Dropdown(description="X-Axis:")
y_dropdown = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style='success')
output_area = widgets.Output()

# --- 5. Update columns when user picks a table or plot type ---
def update_columns(change):
    clear_output(wait=True)
    display(ui)

    table = table_dropdown.value
    ptype = plot_type.value

    if not table:
        return

    # Load sample data
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 100', conn)
    conn.close()

    # Choose columns based on plot type
    if ptype == "Scatter":
        numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
        x_dropdown.options = numeric_cols
        y_dropdown.options = numeric_cols
    else:
        all_cols = df.columns.tolist()
        x_dropdown.options = all_cols
        y_dropdown.options = all_cols

# --- 6. When user clicks "Generate Plot" ---
def generate_plot(b):
    with output_area:
        clear_output(wait=True)

        table = table_dropdown.value
        x_col = x_dropdown.value
        y_col = y_dropdown.value
        ptype = plot_type.value

        if not table or not x_col or not y_col:
            print("⚠️ Please select a table and valid columns.")
            return

        # Load full table data
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', conn)
        conn.close()

        # Filter by age if available
        if 'age' in df.columns:
            df['age'] = pd.to_numeric(df['age'], errors='coerce')
            df = df[(df['age'] >= 18) & (df['age'] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        if df.empty:
            print("No data available after filtering.")
            return

        # --- Create Plot ---
        fig, ax = plt.subplots(figsize=(6, 4))

        if ptype == "Scatter":
            # Convert to numbers and drop missing values
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col])

            ax.scatter(df[x_col], df[y_col], color='blue', alpha=0.7)

            # Add best-fit line if checked
            if best_fit.value and len(df) > 1:
                m, b = np.polyfit(df[x_col], df[y_col], 1)
                ax.plot(df[x_col], m * df[x_col] + b, color='red')
            ax.set_ylabel(y_col)

        else:  # --- Bar Plot ---
            # Handle grouped bars when both are categorical (e.g., gender vs. age)
            if df[x_col].dtype == 'object' or df[y_col].dtype == 'object':
                grouped = pd.crosstab(df[x_col], df[y_col])
                grouped.plot(kind='bar', ax=ax, width=0.8, alpha=0.8)
                plt.xticks(rotation=45, ha='right')
                ax.set_ylabel("Count")
                ax.legend(title=y_col)
            else:
                # Default single-variable count bar
                grouped = df.groupby(x_col)[y_col].count()
                grouped.plot(kind='bar', ax=ax, color='skyblue')
                plt.xticks(rotation=45, ha='right')
                ax.set_ylabel("Count")

        # --- Final Plot Formatting ---
        ax.set_xlabel(x_col)
        ax.set_title(f"{y_col} vs. {x_col} ({ptype})")
        ax.grid(True, linestyle='--', alpha=0.5)
        plt.tight_layout()

        # Show image inside widget
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        display(widgets.Image(value=buf.read(), format='png'))
        plt.close(fig)

# --- 7. Link Widgets to Events ---
table_dropdown.observe(update_columns, names='value')
plot_type.observe(update_columns, names='value')
generate_button.on_click(generate_plot)

# --- 8. Display Interface ---
ui = widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_dropdown, y_dropdown]),
    widgets.HBox([generate_button, best_fit]),
    output_area
])

display(ui)


VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [9]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# --- 2. Database Setup ---
DB_NAME = 'Dataset.db'  # Ensure your database file is in the same directory

# --- 3. Get All Tables from Database ---
try:
    conn = sqlite3.connect(DB_NAME)
    tables = pd.read_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
        conn
    )
    table_list = tables['name'].tolist()
    conn.close()
except Exception as e:
    print("Error loading database:", e)
    table_list = []

# --- 4. Create Widgets ---
table_dropdown = widgets.Dropdown(options=table_list, description="Table:")
plot_type = widgets.Dropdown(options=["Scatter", "Bar"], description="Plot Type:")
x_dropdown = widgets.Dropdown(description="X-Axis:")
y_dropdown = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style='success')
output_area = widgets.Output()

# --- 5. Update columns when user picks a table or plot type ---
def update_columns(change):
    clear_output(wait=True)
    display(ui)

    table = table_dropdown.value
    ptype = plot_type.value

    if not table:
        return

    # Load sample data
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 100', conn)
    conn.close()

    # Detect ID-like columns (likely unique identifiers)
    id_like = [col for col in df.columns if 'id' in col.lower() or df[col].nunique() == len(df)]

    # Choose columns for X and Y
    all_cols = df.columns.tolist()
    x_allowed = [col for col in all_cols if col not in id_like]  # exclude IDs only from X
    y_allowed = all_cols  # keep everything, including IDs

    if ptype == "Scatter":
        numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
        x_dropdown.options = [col for col in numeric_cols if col not in id_like]
        y_dropdown.options = numeric_cols
    else:
        x_dropdown.options = x_allowed
        y_dropdown.options = y_allowed

# --- 6. When user clicks "Generate Plot" ---
def generate_plot(b):
    with output_area:
        clear_output(wait=True)

        table = table_dropdown.value
        x_col = x_dropdown.value
        y_col = y_dropdown.value
        ptype = plot_type.value

        if not table or not x_col or not y_col:
            print("⚠️ Please select a table and valid columns.")
            return

        # Load full table data
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', conn)
        conn.close()

        # Filter by age if available
        if 'age' in df.columns:
            df['age'] = pd.to_numeric(df['age'], errors='coerce')
            df = df[(df['age'] >= 18) & (df['age'] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        if df.empty:
            print("No data available after filtering.")
            return

        # --- Create Plot ---
        fig, ax = plt.subplots(figsize=(6, 4))

        if ptype == "Scatter":
            # Convert to numbers and drop missing values
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col])

            ax.scatter(df[x_col], df[y_col], color='blue', alpha=0.7)

            # Add best-fit line if checked
            if best_fit.value and len(df) > 1:
                m, b = np.polyfit(df[x_col], df[y_col], 1)
                ax.plot(df[x_col], m * df[x_col] + b, color='red')
            ax.set_ylabel(y_col)

        else:  # --- Bar Plot ---
            # Handle grouped bars when both are categorical (e.g., gender vs. age)
            if df[x_col].dtype == 'object' or df[y_col].dtype == 'object':
                grouped = pd.crosstab(df[x_col], df[y_col])
                grouped.plot(kind='bar', ax=ax, width=0.8, alpha=0.8)
                plt.xticks(rotation=45, ha='right')
                ax.set_ylabel("Count")
                ax.legend(title=y_col)
            else:
                # Default single-variable count bar
                grouped = df.groupby(x_col)[y_col].count()
                grouped.plot(kind='bar', ax=ax, color='skyblue')
                plt.xticks(rotation=45, ha='right')
                ax.set_ylabel("Count")

        # --- Final Plot Formatting ---
        ax.set_xlabel(x_col)
        ax.set_title(f"{y_col} vs. {x_col} ({ptype})")
        ax.grid(True, linestyle='--', alpha=0.5)
        plt.tight_layout()

        # Show image inside widget
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        display(widgets.Image(value=buf.read(), format='png'))
        plt.close(fig)

# --- 7. Link Widgets to Events ---
table_dropdown.observe(update_columns, names='value')
plot_type.observe(update_columns, names='value')
generate_button.on_click(generate_plot)

# --- 8. Display Interface ---
ui = widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_dropdown, y_dropdown]),
    widgets.HBox([generate_button, best_fit]),
    output_area
])

display(ui)


VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [10]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# --- 2. Database Setup ---
DB_NAME = 'Dataset.db'  # Ensure your database file is in the same directory

# --- 3. Get All Tables from Database ---
try:
    conn = sqlite3.connect(DB_NAME)
    tables = pd.read_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
        conn
    )
    table_list = tables['name'].tolist()
    conn.close()
except Exception as e:
    print("Error loading database:", e)
    table_list = []

# --- 4. Create Widgets ---
table_dropdown = widgets.Dropdown(options=table_list, description="Table:")
plot_type = widgets.Dropdown(options=["Scatter", "Bar"], description="Plot Type:")
x_dropdown = widgets.Dropdown(description="X-Axis:")
y_dropdown = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style='success')
output_area = widgets.Output()

# --- 5. Update columns when user picks a table or plot type ---
def update_columns(change):
    clear_output(wait=True)
    display(ui)

    table = table_dropdown.value
    ptype = plot_type.value

    if not table:
        return

    # Load sample data
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 100', conn)
    conn.close()

    # Choose columns based on plot type
    if ptype == "Scatter":
        numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
        x_dropdown.options = numeric_cols
        y_dropdown.options = numeric_cols
    else:
        all_cols = df.columns.tolist()
        x_dropdown.options = all_cols
        y_dropdown.options = all_cols

# --- 6. When user clicks "Generate Plot" ---
def generate_plot(b):
    with output_area:
        clear_output(wait=True)

        table = table_dropdown.value
        x_col = x_dropdown.value
        y_col = y_dropdown.value
        ptype = plot_type.value

        if not table or not x_col or not y_col:
            print("⚠️ Please select a table and valid columns.")
            return

        # Load full table data
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', conn)
        conn.close()

        # Filter by age if available
        if 'age' in df.columns:
            df['age'] = pd.to_numeric(df['age'], errors='coerce')
            df = df[(df['age'] >= 18) & (df['age'] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        if df.empty:
            print("No data available after filtering.")
            return

        # --- Create Plot ---
        fig, ax = plt.subplots(figsize=(6, 4))

        if ptype == "Scatter":
            # Convert to numbers and drop missing values
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col])

            ax.scatter(df[x_col], df[y_col], color='blue', alpha=0.7)

            # Add best-fit line if checked
            if best_fit.value and len(df) > 1:
                m, b = np.polyfit(df[x_col], df[y_col], 1)
                ax.plot(df[x_col], m * df[x_col] + b, color='red')
            ax.set_ylabel(y_col)

        else:  # --- Bar Plot ---
            # Handle grouped bars when both are categorical (e.g., gender vs. age)
            if df[x_col].dtype == 'object' or df[y_col].dtype == 'object':
                grouped = pd.crosstab(df[x_col], df[y_col])
                grouped.plot(kind='bar', ax=ax, width=0.8, alpha=0.8)
                plt.xticks(rotation=45, ha='right')
                ax.set_ylabel("Count")
                ax.legend(title=y_col)
            else:
                # Default single-variable count bar
                grouped = df.groupby(x_col)[y_col].count()
                grouped.plot(kind='bar', ax=ax, color='skyblue')
                plt.xticks(rotation=45, ha='right')
                ax.set_ylabel("Count")

        # --- Final Plot Formatting ---
        ax.set_xlabel(x_col)
        ax.set_title(f"{y_col} vs. {x_col} ({ptype})")
        ax.grid(True, linestyle='--', alpha=0.5)
        plt.tight_layout()

        # Show image inside widget
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        display(widgets.Image(value=buf.read(), format='png'))
        plt.close(fig)

# --- 7. Link Widgets to Events ---
table_dropdown.observe(update_columns, names='value')
plot_type.observe(update_columns, names='value')
generate_button.on_click(generate_plot)

# --- 8. Display Interface ---
ui = widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_dropdown, y_dropdown]),
    widgets.HBox([generate_button, best_fit]),
    output_area
])

display(ui)


VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [11]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# --- 2. Database Setup ---
DB_NAME = 'Dataset.db'  # Ensure your database file is in the same directory

# --- 3. Get All Tables from Database ---
try:
    conn = sqlite3.connect(DB_NAME)
    tables = pd.read_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
        conn
    )
    table_list = tables['name'].tolist()
    conn.close()
except Exception as e:
    print("Error loading database:", e)
    table_list = []

# --- 4. Create Widgets ---
table_dropdown = widgets.Dropdown(options=table_list, description="Table:")
plot_type = widgets.Dropdown(options=["Scatter", "Bar"], description="Plot Type:")
x_dropdown = widgets.Dropdown(description="X-Axis:")
y_dropdown = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style='success')
output_area = widgets.Output()

# --- 5. Update columns when user picks a table or plot type ---
def update_columns(change):
    clear_output(wait=True)
    display(ui)

    table = table_dropdown.value
    ptype = plot_type.value

    if not table:
        return

    # Load sample data
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 100', conn)
    conn.close()

    # Choose columns based on plot type
    if ptype == "Scatter":
        numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
        x_dropdown.options = numeric_cols
        y_dropdown.options = numeric_cols
    else:
        all_cols = df.columns.tolist()
        x_dropdown.options = all_cols
        y_dropdown.options = all_cols

# --- 6. When user clicks "Generate Plot" ---
def generate_plot(b):
    with output_area:
        clear_output(wait=True)

        table = table_dropdown.value
        x_col = x_dropdown.value
        y_col = y_dropdown.value
        ptype = plot_type.value

        if not table or not x_col or not y_col:
            print("⚠️ Please select a table and valid columns.")
            return

        # Load full table data
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', conn)
        conn.close()

        # Filter by age if available
        if 'age' in df.columns:
            df['age'] = pd.to_numeric(df['age'], errors='coerce')
            df = df[(df['age'] >= 18) & (df['age'] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        if df.empty:
            print("No data available after filtering.")
            return

        # --- Create Plot ---
        fig, ax = plt.subplots(figsize=(6, 4))

        if ptype == "Scatter":
            # Convert to numbers and drop missing values
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col])

            ax.scatter(df[x_col], df[y_col], color='blue', alpha=0.7)

            # Add best-fit line if checked
            if best_fit.value and len(df) > 1:
                m, b = np.polyfit(df[x_col], df[y_col], 1)
                ax.plot(df[x_col], m * df[x_col] + b, color='red')
            ax.set_ylabel(y_col)

        else:  # --- Bar Plot ---
            # Handle grouped bars when both are categorical (e.g., gender vs. age)
            if df[x_col].dtype == 'object' or df[y_col].dtype == 'object':
                grouped = pd.crosstab(df[x_col], df[y_col])
                grouped.plot(kind='bar', ax=ax, width=0.8, alpha=0.8)
                plt.xticks(rotation=45, ha='right')
                ax.set_ylabel("Count")
                ax.legend(title=y_col)
            else:
                # Default single-variable count bar
                grouped = df.groupby(x_col)[y_col].count()
                grouped.plot(kind='bar', ax=ax, color='skyblue')
                plt.xticks(rotation=45, ha='right')
                ax.set_ylabel("Count")

        # --- Final Plot Formatting ---
        ax.set_xlabel(x_col)
        ax.set_title(f"{y_col} vs. {x_col} ({ptype})")
        ax.grid(True, linestyle='--', alpha=0.5)
        plt.tight_layout()

        # Show image inside widget
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        display(widgets.Image(value=buf.read(), format='png'))
        plt.close(fig)

# --- 7. Link Widgets to Events ---
table_dropdown.observe(update_columns, names='value')
plot_type.observe(update_columns, names='value')
generate_button.on_click(generate_plot)

# --- 8. Display Interface ---
ui = widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_dropdown, y_dropdown]),
    widgets.HBox([generate_button, best_fit]),
    output_area
])

display(ui)


VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [1]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# --- 2. Database Setup ---
DB_NAME = 'Dataset.db'  # Ensure your database file is in the same directory

# --- 3. Get All Tables from Database ---
try:
    conn = sqlite3.connect(DB_NAME)
    tables = pd.read_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
        conn
    )
    table_list = tables['name'].tolist()
    conn.close()
except Exception as e:
    print("Error loading database:", e)
    table_list = []

# --- 4. Create Widgets ---
table_dropdown = widgets.Dropdown(options=table_list, description="Table:")
plot_type = widgets.Dropdown(options=["Scatter", "Bar"], description="Plot Type:")
x_dropdown = widgets.Dropdown(description="X-Axis:")
y_dropdown = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style='success')
output_area = widgets.Output()

# --- 5. Update columns when user picks a table or plot type ---
def update_columns(change):
    clear_output(wait=True)
    display(ui)

    table = table_dropdown.value
    ptype = plot_type.value

    if not table:
        return

    # Load sample data
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 100', conn)
    conn.close()

    # Choose columns based on plot type
    if ptype == "Scatter":
        numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
        x_dropdown.options = numeric_cols
        y_dropdown.options = numeric_cols
    else:
        all_cols = df.columns.tolist()
        x_dropdown.options = all_cols
        y_dropdown.options = all_cols

# --- 6. When user clicks "Generate Plot" ---
def generate_plot(b):
    with output_area:
        clear_output(wait=True)

        table = table_dropdown.value
        x_col = x_dropdown.value
        y_col = y_dropdown.value
        ptype = plot_type.value

        if not table or not x_col or not y_col:
            print("⚠️ Please select a table and valid columns.")
            return

        # Load full table data
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', conn)
        conn.close()

        # Filter by age if available
        if 'age' in df.columns:
            df['age'] = pd.to_numeric(df['age'], errors='coerce')
            df = df[(df['age'] >= 18) & (df['age'] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        if df.empty:
            print("No data available after filtering.")
            return

        # --- Create Plot ---
        fig, ax = plt.subplots(figsize=(6, 4))

        if ptype == "Scatter":
            # Convert to numbers and drop missing values
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col])

            ax.scatter(df[x_col], df[y_col], color='blue', alpha=0.7)

            # Add best-fit line if checked
            if best_fit.value and len(df) > 1:
                m, b = np.polyfit(df[x_col], df[y_col], 1)
                ax.plot(df[x_col], m * df[x_col] + b, color='red')
            ax.set_ylabel(y_col)

        else:  # --- Bar Plot ---
            # Handle grouped bars when both are categorical (e.g., gender vs. age)
            if df[x_col].dtype == 'object' or df[y_col].dtype == 'object':
                grouped = pd.crosstab(df[x_col], df[y_col])
                grouped.plot(kind='bar', ax=ax, width=0.8, alpha=0.8)
                plt.xticks(rotation=45, ha='right')
                ax.set_ylabel("Count")
                ax.legend(title=y_col)
            else:
                # Default single-variable count bar
                grouped = df.groupby(x_col)[y_col].count()
                grouped.plot(kind='bar', ax=ax, color='skyblue')
                plt.xticks(rotation=45, ha='right')
                ax.set_ylabel("Count")

        # --- Final Plot Formatting ---
        ax.set_xlabel(x_col)
        ax.set_title(f"{y_col} vs. {x_col} ({ptype})")
        ax.grid(True, linestyle='--', alpha=0.5)
        plt.tight_layout()

        # Show image inside widget
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        display(widgets.Image(value=buf.read(), format='png'))
        plt.close(fig)

# --- 7. Link Widgets to Events ---
table_dropdown.observe(update_columns, names='value')
plot_type.observe(update_columns, names='value')
generate_button.on_click(generate_plot)

# --- 8. Display Interface ---
ui = widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_dropdown, y_dropdown]),
    widgets.HBox([generate_button, best_fit]),
    output_area
])

display(ui)


VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [2]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# --- 2. Database Setup ---
DB_NAME = 'Dataset.db' # Ensure your database file is in the same directory

# --- 3. Get All Tables from Database ---
try:
    conn = sqlite3.connect(DB_NAME)
    tables = pd.read_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
        conn
    )
    table_list = tables['name'].tolist()
    conn.close()
except Exception as e:
    print("Error loading database:", e)
    table_list = []

# --- 4. Create Widgets ---
table_dropdown = widgets.Dropdown(options=table_list, description="Table:")
plot_type = widgets.Dropdown(options=["Scatter", "Bar"], description="Plot Type:")
x_dropdown = widgets.Dropdown(description="X-Axis:")
y_dropdown = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style='success')
output_area = widgets.Output()

# --- 5. Update columns when user picks a table or plot type ---
def update_columns(change):
    # This keeps the UI from jumping around
    with output_area:
        pass 
    
    table = table_dropdown.value
    ptype = plot_type.value

    if not table:
        return

    # Load sample data to find column types
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 50', conn)
    conn.close()

    # Choose columns based on plot type
    if ptype == "Scatter":
        # Scatter plots need numbers
        numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
        x_dropdown.options = numeric_cols
        y_dropdown.options = numeric_cols
        best_fit.layout.display = 'flex' # Show best-fit option
    else:
        # Bar plots can use any column
        all_cols = df.columns.tolist()
        x_dropdown.options = all_cols
        y_dropdown.options = all_cols
        best_fit.layout.display = 'none' # Hide best-fit option

# --- 6. Plotting Helper Functions (Cleaner) ---

def plot_scatter_on_axis(ax, df, x_col, y_col, add_best_fit):
    """Creates a scatter plot on the given axes."""
    # Convert to numbers and drop missing values
    df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
    df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
    df = df.dropna(subset=[x_col, y_col])

    if df.empty:
        ax.text(0.5, 0.5, "No numeric data to plot.", ha='center', va='center')
        return

    ax.scatter(df[x_col], df[y_col], color='blue', alpha=0.7)
    ax.set_ylabel(y_col)

    # Add best-fit line if checked
    if add_best_fit and len(df) > 1:
        m, b = np.polyfit(df[x_col], df[y_col], 1)
        ax.plot(df[x_col], m * df[x_col] + b, color='red', linewidth=2)

def plot_bar_on_axis(ax, df, x_col, y_col):
    """Creates a bar plot, including the grouped 'gender vs age' style."""
    df = df.dropna(subset=[x_col, y_col])
    
    if df.empty:
        ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center')
        return

    # If X or Y are categorical (text), create a grouped bar chart
    is_x_categorical = df[x_col].dtype == 'object'
    is_y_categorical = df[y_col].dtype == 'object'

    if is_x_categorical or is_y_categorical:
        # This is the "color for gender vs age" plot (e.g., crosstab)
        grouped = pd.crosstab(df[x_col], df[y_col])
        grouped.plot(kind='bar', ax=ax, width=0.8, alpha=0.8)
        ax.set_ylabel("Count")
        ax.legend(title=y_col, loc='upper right')
    else:
        # Otherwise, just make a simple bar chart (count of X)
        df.groupby(x_col)[y_col].count().plot(kind='bar', ax=ax, color='skyblue')
        ax.set_ylabel("Count")
    
    plt.xticks(rotation=45, ha='right')


# --- 7. When user clicks "Generate Plot" (Simple Version) ---
def generate_plot(b):
    with output_area:
        clear_output(wait=True)

        # --- 1. Get User Selections ---
        table = table_dropdown.value
        x_col = x_dropdown.value
        y_col = y_dropdown.value
        ptype = plot_type.value

        if not table or not x_col or not y_col:
            print("⚠️ Please select a table and valid X/Y columns.")
            return

        # --- 2. Load Full Data ---
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', conn)
        conn.close()

        # --- 3. Apply Age Filter (if 'age' column exists) ---
        if 'age' in df.columns:
            df['age'] = pd.to_numeric(df['age'], errors='coerce')
            df = df.dropna(subset=['age'])
            df = df[(df['age'] >= 18) & (df['age'] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        if df.empty:
            print("No data available after filtering.")
            return

        # --- 4. Create Plot ---
        fig, ax = plt.subplots(figsize=(7, 4)) # A bit wider for labels

        if ptype == "Scatter":
            plot_scatter_on_axis(ax, df, x_col, y_col, best_fit.value)
        else:
            plot_bar_on_axis(ax, df, x_col, y_col)

        # --- 5. Final Plot Formatting ---
        ax.set_xlabel(x_col)
        ax.set_title(f"{y_col} vs. {x_col} ({ptype})")
        ax.grid(True, linestyle='--', alpha=0.5)
        plt.tight_layout()

        # --- 6. Show Image in Output ---
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        display(widgets.Image(value=buf.read(), format='png'))
        plt.close(fig)

# --- 8. Link Widgets to Events ---
table_dropdown.observe(update_columns, names='value')
plot_type.observe(update_columns, names='value')
generate_button.on_click(generate_plot)

# --- 9. Display Interface ---
# A simple layout with two rows for controls
controls = widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_dropdown, y_dropdown]),
    widgets.HBox([generate_button, best_fit])
])

# The final UI is the controls on top of the output area
ui = widgets.VBox([controls, output_area])

# Trigger the update once to populate the column dropdowns
update_columns(None)

display(ui)

VBox(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', '…

In [3]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# --- 2. Database Setup ---
DB_NAME = 'Dataset.db'

# --- 3. Get All Tables from Database ---
try:
    conn = sqlite3.connect(DB_NAME)
    tables = pd.read_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
        conn
    )
    table_list = tables['name'].tolist()
    conn.close()
except Exception as e:
    print("Error loading database:", e)
    table_list = []

# --- 4. Create Widgets ---
table_dropdown = widgets.Dropdown(options=table_list, description="Table:")
plot_type = widgets.Dropdown(options=["Scatter", "Bar"], description="Plot Type:")
x_dropdown = widgets.Dropdown(description="X-Axis:")
y_dropdown = widgets.Dropdown(description="Y-Axis:")
# --- NEW WIDGET ---
color_dropdown = widgets.Dropdown(description="Color by:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style='success')
output_area = widgets.Output()

# --- 5. Update columns when user picks a table or plot type ---
def update_columns(change):
    with output_area:
        pass # Keep UI stable
    
    table = table_dropdown.value
    ptype = plot_type.value

    if not table:
        return

    # Load sample data to find column types
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 50', conn)
    conn.close()

    # Get different types of columns
    all_cols = df.columns.tolist()
    numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
    # Get text/categorical columns for the new 'color' dropdown
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    color_options = ['None'] + categorical_cols # Add 'None' option

    if ptype == "Scatter":
        # X and Y MUST be numeric
        x_dropdown.options = numeric_cols
        y_dropdown.options = numeric_cols
        # Color dropdown gets text columns
        color_dropdown.options = color_options
        color_dropdown.layout.display = 'flex' # Show color dropdown
        best_fit.layout.display = 'flex'
    else:
        # Bar plots can use any column
        x_dropdown.options = all_cols
        y_dropdown.options = all_cols
        color_dropdown.layout.display = 'none' # Hide color dropdown
        best_fit.layout.display = 'none'

    # Set default values to prevent errors
    if x_dropdown.options:
        x_dropdown.value = x_dropdown.options[0]
    if y_dropdown.options and len(y_dropdown.options) > 1:
        y_dropdown.value = y_dropdown.options[1]
    elif y_dropdown.options:
        y_dropdown.value = y_dropdown.options[0]
    if color_dropdown.options:
        color_dropdown.value = 'None'


# --- 6. Plotting Helper Functions (Cleaner) ---

def plot_scatter_on_axis(ax, df, x_col, y_col, color_col, add_best_fit):
    """Creates a scatter plot, now with optional coloring."""
    
    # Convert to numbers and drop missing values for X and Y
    df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
    df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
    df = df.dropna(subset=[x_col, y_col])

    if df.empty:
        ax.text(0.5, 0.5, "No numeric data to plot.", ha='center', va='center')
        return

    # --- Color Logic ---
    if color_col and color_col != 'None' and color_col in df.columns:
        df = df.dropna(subset=[color_col]) # Drop rows where color is missing
        categories = df[color_col].unique()
        
        # Plot each category with a different color
        for category in categories:
            df_group = df[df[color_col] == category]
            ax.scatter(df_group[x_col], df_group[y_col], alpha=0.7, label=str(category))
        if len(categories) > 0:
            ax.legend(title=color_col)
    else:
        # No color selected, plot all as one color
        ax.scatter(df[x_col], df[y_col], color='blue', alpha=0.7)

    ax.set_ylabel(y_col)

    # --- Best-Fit Line Logic ---
    # Plot one single best-fit line for all data, regardless of color
    if add_best_fit and len(df) > 1:
        m, b = np.polyfit(df[x_col], df[y_col], 1)
        ax.plot(df[x_col], m * df[x_col] + b, color='red', linewidth=2)


def plot_bar_on_axis(ax, df, x_col, y_col):
    """Creates a bar plot, including the grouped 'gender vs age' style."""
    df = df.dropna(subset=[x_col, y_col])
    
    if df.empty:
        ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center')
        return

    is_x_categorical = df[x_col].dtype == 'object'
    is_y_categorical = df[y_col].dtype == 'object'

    if is_x_categorical or is_y_categorical:
        # This is the "color for gender vs age" plot
        grouped = pd.crosstab(df[x_col], df[y_col])
        grouped.plot(kind='bar', ax=ax, width=0.8, alpha=0.8)
        ax.set_ylabel("Count")
        ax.legend(title=y_col, loc='upper right')
    else:
        df.groupby(x_col)[y_col].count().plot(kind='bar', ax=ax, color='skyblue')
        ax.set_ylabel("Count")
    
    plt.xticks(rotation=45, ha='right')


# --- 7. When user clicks "Generate Plot" (Simple Version) ---
def generate_plot(b):
    with output_area:
        clear_output(wait=True)

        # --- 1. Get User Selections ---
        table = table_dropdown.value
        x_col = x_dropdown.value
        y_col = y_dropdown.value
        ptype = plot_type.value
        color_col = color_dropdown.value # <-- GET NEW VALUE
        
        if not table or not x_col or not y_col:
            print("⚠️ Please select a table and valid X/Y columns.")
            return

        # --- 2. Load Full Data ---
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', conn)
        conn.close()

        # --- 3. Apply Age Filter (if 'age' column exists) ---
        if 'age' in df.columns:
            df['age'] = pd.to_numeric(df['age'], errors='coerce')
            df = df.dropna(subset=['age'])
            df = df[(df['age'] >= 18) & (df['age'] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        if df.empty:
            print("No data available after filtering.")
            return

        # --- 4. Create Plot ---
        fig, ax = plt.subplots(figsize=(7, 4))

        if ptype == "Scatter":
            # --- PASS NEW VALUE TO PLOTTER ---
            plot_scatter_on_axis(ax, df, x_col, y_col, color_col, best_fit.value)
        else:
            plot_bar_on_axis(ax, df, x_col, y_col)

        # --- 5. Final Plot Formatting ---
        ax.set_xlabel(x_col)
        ax.set_title(f"{y_col} vs. {x_col} ({ptype})")
        ax.grid(True, linestyle='--', alpha=0.5)
        plt.tight_layout()

        # --- 6. Show Image in Output ---
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        display(widgets.Image(value=buf.read(), format='png'))
        plt.close(fig)

# --- 8. Link Widgets to Events ---
table_dropdown.observe(update_columns, names='value')
plot_type.observe(update_columns, names='value')
generate_button.on_click(generate_plot)

# --- 9. Display Interface ---
# Add the new 'color_dropdown' to the layout
controls = widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_dropdown, y_dropdown, color_dropdown]), # <-- UPDATED ROW
    widgets.HBox([generate_button, best_fit])
])

# The final UI
ui = widgets.VBox([controls, output_area])

# Trigger the update once to populate the column dropdowns
update_columns(None)

display(ui)

VBox(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', '…

In [4]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker # <-- IMPORTED FOR TICK MARKS
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# --- 2. Database Setup ---
DB_NAME = 'Dataset.db'

# --- 3. Get All Tables from Database ---
try:
    conn = sqlite3.connect(DB_NAME)
    tables = pd.read_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
        conn
    )
    table_list = tables['name'].tolist()
    conn.close()
except Exception as e:
    print("Error loading database:", e)
    table_list = []

# --- 4. Create Widgets ---
table_dropdown = widgets.Dropdown(options=table_list, description="Table:")
plot_type = widgets.Dropdown(options=["Scatter", "Bar"], description="Plot Type:")
x_dropdown = widgets.Dropdown(description="X-Axis:")
y_dropdown = widgets.Dropdown(description="Y-Axis:")
color_dropdown = widgets.Dropdown(description="Color by:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style='success')
output_area = widgets.Output()

# --- 5. Update columns when user picks a table or plot type ---
def update_columns(change):
    with output_area:
        pass # Keep UI stable
    
    table = table_dropdown.value
    ptype = plot_type.value

    if not table:
        return

    # Load sample data to find column types
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 50', conn)
    conn.close()

    # Get different types of columns
    all_cols = df.columns.tolist()
    numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    color_options = ['None'] + categorical_cols # Add 'None' option

    if ptype == "Scatter":
        x_dropdown.options = numeric_cols
        y_dropdown.options = numeric_cols
        color_dropdown.options = color_options
        color_dropdown.layout.display = 'flex' 
        best_fit.layout.display = 'flex'
    else:
        x_dropdown.options = all_cols
        y_dropdown.options = all_cols
        color_dropdown.layout.display = 'none'
        best_fit.layout.display = 'none'

    # Set default values to prevent errors
    if x_dropdown.options:
        x_dropdown.value = x_dropdown.options[0]
    if y_dropdown.options and len(y_dropdown.options) > 1:
        y_dropdown.value = y_dropdown.options[1]
    elif y_dropdown.options:
        y_dropdown.value = y_dropdown.options[0]
    if color_dropdown.options:
        color_dropdown.value = 'None'


# --- 6. Plotting Helper Functions (Cleaner) ---

def plot_scatter_on_axis(ax, df, x_col, y_col, color_col, add_best_fit):
    """Creates a scatter plot, now with optional coloring and integer axis ticks."""
    
    df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
    df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
    df = df.dropna(subset=[x_col, y_col])

    if df.empty:
        ax.text(0.5, 0.5, "No numeric data to plot.", ha='center', va='center')
        return

    # --- Color Logic ---
    if color_col and color_col != 'None' and color_col in df.columns:
        df = df.dropna(subset=[color_col]) 
        categories = df[color_col].unique()
        
        for category in categories:
            df_group = df[df[color_col] == category]
            ax.scatter(df_group[x_col], df_group[y_col], alpha=0.7, label=str(category))
        if len(categories) > 0:
            ax.legend(title=color_col)
    else:
        ax.scatter(df[x_col], df[y_col], color='blue', alpha=0.7)

    ax.set_ylabel(y_col)

    # --- Best-Fit Line Logic ---
    if add_best_fit and len(df) > 1:
        m, b = np.polyfit(df[x_col], df[y_col], 1)
        ax.plot(df[x_col], m * df[x_col] + b, color='red', linewidth=2)
        
    # --- NEW: Set Integer Ticks for 'age' ---
    if x_col == 'age':
        ax.xaxis.set_major_locator(mticker.MultipleLocator(1))
    if y_col == 'age':
        ax.yaxis.set_major_locator(mticker.MultipleLocator(1))


def plot_bar_on_axis(ax, df, x_col, y_col):
    """Creates a bar plot. 'age' data should be rounded before calling this."""
    df = df.dropna(subset=[x_col, y_col])
    
    if df.empty:
        ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center')
        return

    is_x_categorical = df[x_col].dtype == 'object'
    is_y_categorical = df[y_col].dtype == 'object'
    # Check for integer type as well, since we converted 'age'
    is_x_numeric_like = pd.api.types.is_numeric_dtype(df[x_col])
    is_y_numeric_like = pd.api.types.is_numeric_dtype(df[y_col])

    if (is_x_categorical or is_y_categorical) or (x_col == 'age' and is_y_categorical):
        # This is the "color for gender vs age" plot
        grouped = pd.crosstab(df[x_col], df[y_col])
        grouped.plot(kind='bar', ax=ax, width=0.8, alpha=0.8)
        ax.set_ylabel("Count")
        ax.legend(title=y_col, loc='upper right')
    else:
        df.groupby(x_col)[y_col].count().plot(kind='bar', ax=ax, color='skyblue')
        ax.set_ylabel("Count")
    
    plt.xticks(rotation=45, ha='right')


# --- 7. When user clicks "Generate Plot" (Simple Version) ---
def generate_plot(b):
    with output_area:
        clear_output(wait=True)

        # --- 1. Get User Selections ---
        table = table_dropdown.value
        x_col = x_dropdown.value
        y_col = y_dropdown.value
        ptype = plot_type.value
        color_col = color_dropdown.value
        
        if not table or not x_col or not y_col:
            print("⚠️ Please select a table and valid X/Y columns.")
            return

        # --- 2. Load Full Data ---
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', conn)
        conn.close()

        # --- 3. Apply Age Filter (if 'age' column exists) ---
        if 'age' in df.columns:
            df['age'] = pd.to_numeric(df['age'], errors='coerce')
            df = df.dropna(subset=['age'])
            df = df[(df['age'] >= 18) & (df['age'] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        if df.empty:
            print("No data available after filtering.")
            return

        # --- 4. Create Plot ---
        fig, ax = plt.subplots(figsize=(7, 4))

        if ptype == "Scatter":
            plot_scatter_on_axis(ax, df, x_col, y_col, color_col, best_fit.value)
        else:
            # --- NEW: Round ages for Bar Charts ---
            if 'age' in df.columns:
                df['age'] = df['age'].round().astype(int)
            plot_bar_on_axis(ax, df, x_col, y_col)

        # --- 5. Final Plot Formatting ---
        ax.set_xlabel(x_col)
        ax.set_title(f"{y_col} vs. {x_col} ({ptype})")
        ax.grid(True, linestyle='--', alpha=0.5)
        plt.tight_layout()

        # --- 6. Show Image in Output ---
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        display(widgets.Image(value=buf.read(), format='png'))
        plt.close(fig)

# --- 8. Link Widgets to Events ---
table_dropdown.observe(update_columns, names='value')
plot_type.observe(update_columns, names='value')
generate_button.on_click(generate_plot)

# --- 9. Display Interface ---
controls = widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_dropdown, y_dropdown, color_dropdown]),
    widgets.HBox([generate_button, best_fit])
])

ui = widgets.VBox([controls, output_area])

# Trigger the update once to populate the column dropdowns
update_columns(None)

display(ui)

VBox(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', '…

In [5]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker # For integer axis ticks
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# --- 2. Database Setup ---
DB_NAME = 'Dataset.db'

# --- 3. Get All Tables from Database ---
try:
    conn = sqlite3.connect(DB_NAME)
    tables = pd.read_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
        conn
    )
    table_list = tables['name'].tolist()
    conn.close()
except Exception as e:
    print("Error loading database:", e)
    table_list = []

# --- 4. Create Widgets ---
table_dropdown = widgets.Dropdown(options=table_list, description="Table:")
plot_type = widgets.Dropdown(options=["Scatter", "Bar"], description="Plot Type:")
x_dropdown = widgets.Dropdown(description="X-Axis:")
y_dropdown = widgets.Dropdown(description="Y-Axis:")
color_dropdown = widgets.Dropdown(description="Color by:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style='success')
output_area = widgets.Output()

# --- 5. Update columns when user picks a table or plot type ---
def update_columns(change):
    with output_area:
        pass # Keep UI stable
    
    table = table_dropdown.value
    ptype = plot_type.value

    if not table:
        return

    # Load sample data to find column types
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 50', conn)
    conn.close()

    # Get different types of columns
    all_cols = df.columns.tolist()
    numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    color_options = ['None'] + categorical_cols # Add 'None' option

    if ptype == "Scatter":
        x_dropdown.options = numeric_cols
        y_dropdown.options = numeric_cols
        color_dropdown.options = color_options
        color_dropdown.layout.display = 'flex' 
        best_fit.layout.display = 'flex'
    else:
        x_dropdown.options = all_cols
        y_dropdown.options = all_cols
        color_dropdown.layout.display = 'none'
        best_fit.layout.display = 'none'

    # Set default values to prevent errors
    if x_dropdown.options:
        x_dropdown.value = x_dropdown.options[0]
    if y_dropdown.options and len(y_dropdown.options) > 1:
        y_dropdown.value = y_dropdown.options[1]
    elif y_dropdown.options:
        y_dropdown.value = y_dropdown.options[0]
    if color_dropdown.options:
        color_dropdown.value = 'None'


# --- 6. Plotting Helper Functions (Cleaner) ---

def plot_scatter_on_axis(ax, df, x_col, y_col, color_col, add_best_fit):
    """Creates a scatter plot, now with optional coloring and integer axis ticks."""
    
    df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
    df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
    df = df.dropna(subset=[x_col, y_col])

    if df.empty:
        ax.text(0.5, 0.5, "No numeric data to plot.", ha='center', va='center')
        return

    # --- Color Logic ---
    if color_col and color_col != 'None' and color_col in df.columns:
        df = df.dropna(subset=[color_col]) 
        categories = df[color_col].unique()
        
        for category in categories:
            df_group = df[df[color_col] == category]
            ax.scatter(df_group[x_col], df_group[y_col], alpha=0.7, label=str(category))
        if len(categories) > 0:
            ax.legend(title=color_col)
    else:
        ax.scatter(df[x_col], df[y_col], color='blue', alpha=0.7)

    ax.set_ylabel(y_col)

    # --- Best-Fit Line Logic ---
    if add_best_fit and len(df) > 1:
        m, b = np.polyfit(df[x_col], df[y_col], 1)
        ax.plot(df[x_col], m * df[x_col] + b, color='red', linewidth=2)
        
    # --- Set Integer Ticks for 'age' ---
    if x_col == 'age':
        ax.xaxis.set_major_locator(mticker.MultipleLocator(1))
    if y_col == 'age':
        ax.yaxis.set_major_locator(mticker.MultipleLocator(1))


def plot_bar_on_axis(ax, df, x_col, y_col):
    """Creates a bar plot. 'age' data should be rounded before calling this."""
    df = df.dropna(subset=[x_col, y_col])
    
    if df.empty:
        ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center')
        return

    is_x_categorical = df[x_col].dtype == 'object'
    is_y_categorical = df[y_col].dtype == 'object'
    is_x_numeric_like = pd.api.types.is_numeric_dtype(df[x_col])
    is_y_numeric_like = pd.api.types.is_numeric_dtype(df[y_col])

    if (is_x_categorical or is_y_categorical) or (x_col == 'age' and is_y_categorical):
        # This is the "color for gender vs age" plot
        grouped = pd.crosstab(df[x_col], df[y_col])
        grouped.plot(kind='bar', ax=ax, width=0.8, alpha=0.8)
        ax.set_ylabel("Count")
        ax.legend(title=y_col, loc='upper right')
    else:
        df.groupby(x_col)[y_col].count().plot(kind='bar', ax=ax, color='skyblue')
        ax.set_ylabel("Count")
    
    # Tick rotation is now handled in the main generate_plot function


# --- 7. When user clicks "Generate Plot" (Simple Version) ---
def generate_plot(b):
    with output_area:
        clear_output(wait=True)

        # --- 1. Get User Selections ---
        table = table_dropdown.value
        x_col = x_dropdown.value
        y_col = y_dropdown.value
        ptype = plot_type.value
        color_col = color_dropdown.value
        
        if not table or not x_col or not y_col:
            print("⚠️ Please select a table and valid X/Y columns.")
            return

        # --- 2. Load Full Data ---
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', conn)
        conn.close()

        # --- 3. Apply Age Filter (if 'age' column exists) ---
        if 'age' in df.columns:
            df['age'] = pd.to_numeric(df['age'], errors='coerce')
            df = df.dropna(subset=['age'])
            df = df[(df['age'] >= 18) & (df['age'] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        if df.empty:
            print("No data available after filtering.")
            return

        # --- 4. Create Plot ---
        fig, ax = plt.subplots(figsize=(7, 4))

        if ptype == "Scatter":
            plot_scatter_on_axis(ax, df, x_col, y_col, color_col, best_fit.value)
        else:
            if 'age' in df.columns:
                df['age'] = df['age'].round().astype(int)
            plot_bar_on_axis(ax, df, x_col, y_col)

        # --- 5. Final Plot Formatting ---
        
        # --- Move X-axis to the top ---
        ax.xaxis.tick_top()
        ax.xaxis.set_label_position('top')
        
        # --- Ensure Y-axis is on the left (it is by default) ---
        ax.yaxis.tick_left()
        ax.yaxis.set_label_position('left')
        
        # Set labels
        ax.set_xlabel(x_col)
        ax.set_title(f"{y_col} vs. {x_col} ({ptype})")
        ax.grid(True, linestyle='--', alpha=0.5)

        # --- Adjust tick rotation for top axis ---
        if ptype == "Bar":
            # ha='left' aligns the start of the label with the tick
            plt.xticks(rotation=45, ha='left') 
        
        plt.tight_layout()

        # --- 6. Show Image in Output ---
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        display(widgets.Image(value=buf.read(), format='png'))
        plt.close(fig)

# --- 8. Link Widgets to Events ---
table_dropdown.observe(update_columns, names='value')
plot_type.observe(update_columns, names='value')
generate_button.on_click(generate_plot)

# --- 9. Display Interface ---
controls = widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    # --- THIS IS THE ONLY LINE THAT CHANGED ---
    widgets.HBox([y_dropdown, x_dropdown, color_dropdown]),
    widgets.HBox([generate_button, best_fit])
])

ui = widgets.VBox([controls, output_area])

# Trigger the update once to populate the column dropdowns
update_columns(None)

display(ui)

VBox(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', '…

In [6]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker # For integer axis ticks
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# --- 2. Database Setup ---
DB_NAME = 'Dataset.db'

# --- 3. Get All Tables from Database ---
try:
    conn = sqlite3.connect(DB_NAME)
    tables = pd.read_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
        conn
    )
    table_list = tables['name'].tolist()
    conn.close()
except Exception as e:
    print("Error loading database:", e)
    table_list = []

# --- 4. Create Widgets ---
table_dropdown = widgets.Dropdown(options=table_list, description="Table:")
plot_type = widgets.Dropdown(options=["Scatter", "Bar"], description="Plot Type:")
x_dropdown = widgets.Dropdown(description="X-Axis:")
y_dropdown = widgets.Dropdown(description="Y-Axis:")
color_dropdown = widgets.Dropdown(description="Color by:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style='success')
output_area = widgets.Output()

# --- 5. Update columns when user picks a table or plot type ---
def update_columns(change):
    with output_area:
        pass # Keep UI stable
    
    table = table_dropdown.value
    ptype = plot_type.value

    if not table:
        return

    # Load sample data to find column types
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 50', conn)
    conn.close()

    # Get different types of columns
    all_cols = df.columns.tolist()
    numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    color_options = ['None'] + categorical_cols # Add 'None' option

    if ptype == "Scatter":
        x_dropdown.options = numeric_cols
        y_dropdown.options = numeric_cols
        color_dropdown.options = color_options
        color_dropdown.layout.display = 'flex' 
        best_fit.layout.display = 'flex'
    else:
        x_dropdown.options = all_cols
        y_dropdown.options = all_cols
        color_dropdown.layout.display = 'none'
        best_fit.layout.display = 'none'

    # Set default values to prevent errors
    if x_dropdown.options:
        x_dropdown.value = x_dropdown.options[0]
    if y_dropdown.options and len(y_dropdown.options) > 1:
        y_dropdown.value = y_dropdown.options[1]
    elif y_dropdown.options:
        y_dropdown.value = y_dropdown.options[0]
    if color_dropdown.options:
        color_dropdown.value = 'None'


# --- 6. Plotting Helper Functions (Cleaner) ---

def plot_scatter_on_axis(ax, df, x_col, y_col, color_col, add_best_fit):
    """Creates a scatter plot, now with optional coloring and integer axis ticks."""
    
    df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
    df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
    df = df.dropna(subset=[x_col, y_col])

    if df.empty:
        ax.text(0.5, 0.5, "No numeric data to plot.", ha='center', va='center')
        return

    # --- Color Logic ---
    if color_col and color_col != 'None' and color_col in df.columns:
        df = df.dropna(subset=[color_col]) 
        categories = df[color_col].unique()
        
        for category in categories:
            df_group = df[df[color_col] == category]
            ax.scatter(df_group[x_col], df_group[y_col], alpha=0.7, label=str(category))
        if len(categories) > 0:
            ax.legend(title=color_col)
    else:
        ax.scatter(df[x_col], df[y_col], color='blue', alpha=0.7)

    ax.set_ylabel(y_col)

    # --- Best-Fit Line Logic ---
    if add_best_fit and len(df) > 1:
        m, b = np.polyfit(df[x_col], df[y_col], 1)
        ax.plot(df[x_col], m * df[x_col] + b, color='red', linewidth=2)
        
    # --- Set Integer Ticks for 'age' ---
    if x_col == 'age':
        ax.xaxis.set_major_locator(mticker.MultipleLocator(1))
    if y_col == 'age':
        ax.yaxis.set_major_locator(mticker.MultipleLocator(1))


def plot_bar_on_axis(ax, df, x_col, y_col):
    """Creates a bar plot. 'age' data should be rounded before calling this."""
    df = df.dropna(subset=[x_col, y_col])
    
    if df.empty:
        ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center')
        return

    is_x_categorical = df[x_col].dtype == 'object'
    is_y_categorical = df[y_col].dtype == 'object'
    is_x_numeric_like = pd.api.types.is_numeric_dtype(df[x_col])
    is_y_numeric_like = pd.api.types.is_numeric_dtype(df[y_col])

    if (is_x_categorical or is_y_categorical) or (x_col == 'age' and is_y_categorical):
        # This is the "color for gender vs age" plot
        grouped = pd.crosstab(df[x_col], df[y_col])
        grouped.plot(kind='bar', ax=ax, width=0.8, alpha=0.8)
        ax.set_ylabel("Count")
        ax.legend(title=y_col, loc='upper right')
    else:
        df.groupby(x_col)[y_col].count().plot(kind='bar', ax=ax, color='skyblue')
        ax.set_ylabel("Count")
    
    # Tick rotation is now handled in the main generate_plot function


# --- 7. When user clicks "Generate Plot" (Simple Version) ---
def generate_plot(b):
    with output_area:
        clear_output(wait=True)

        # --- 1. Get User Selections ---
        table = table_dropdown.value
        x_col = x_dropdown.value
        y_col = y_dropdown.value
        ptype = plot_type.value
        color_col = color_dropdown.value
        
        if not table or not x_col or not y_col:
            print("⚠️ Please select a table and valid X/Y columns.")
            return

        # --- 2. Load Full Data ---
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', conn)
        conn.close()

        # --- 3. Apply Age Filter (if 'age' column exists) ---
        if 'age' in df.columns:
            df['age'] = pd.to_numeric(df['age'], errors='coerce')
            df = df.dropna(subset=['age'])
            df = df[(df['age'] >= 18) & (df['age'] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        if df.empty:
            print("No data available after filtering.")
            return

        # --- 4. Create Plot ---
        fig, ax = plt.subplots(figsize=(7, 4))

        if ptype == "Scatter":
            plot_scatter_on_axis(ax, df, x_col, y_col, color_col, best_fit.value)
        else:
            if 'age' in df.columns:
                df['age'] = df['age'].round().astype(int)
            plot_bar_on_axis(ax, df, x_col, y_col)

        # --- 5. Final Plot Formatting ---
        
        # --- RESET: X-axis on bottom, Y-axis on left (default) ---
        ax.xaxis.tick_bottom()
        ax.xaxis.set_label_position('bottom')
        ax.yaxis.tick_left()
        ax.yaxis.set_label_position('left')
        
        # Set labels
        ax.set_xlabel(x_col)
        ax.set_title(f"{y_col} vs. {x_col} ({ptype})")
        ax.grid(True, linestyle='--', alpha=0.5)

        # --- RESET: Adjust tick rotation for bottom axis ---
        if ptype == "Bar":
            # ha='right' aligns the end of the label with the tick
            plt.xticks(rotation=45, ha='right') 
        
        plt.tight_layout()

        # --- 6. Show Image in Output ---
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        display(widgets.Image(value=buf.read(), format='png'))
        plt.close(fig)

# --- 8. Link Widgets to Events ---
table_dropdown.observe(update_columns, names='value')
plot_type.observe(update_columns, names='value')
generate_button.on_click(generate_plot)

# --- 9. Display Interface ---
controls = widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    # Swapped dropdowns as requested
    widgets.HBox([y_dropdown, x_dropdown, color_dropdown]),
    widgets.HBox([generate_button, best_fit])
])

ui = widgets.VBox([controls, output_area])

# Trigger the update once to populate the column dropdowns
update_columns(None)

display(ui)

VBox(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', '…

In [7]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker # For integer axis ticks
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# --- 2. Database Setup ---
DB_NAME = 'Dataset.db'

# --- 3. Get All Tables from Database ---
try:
    conn = sqlite3.connect(DB_NAME)
    tables = pd.read_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
        conn
    )
    table_list = tables['name'].tolist()
    conn.close()
except Exception as e:
    print("Error loading database:", e)
    table_list = []

# --- 4. Create Widgets ---
table_dropdown = widgets.Dropdown(options=table_list, description="Table:")
plot_type = widgets.Dropdown(options=["Scatter", "Bar"], description="Plot Type:")
x_dropdown = widgets.Dropdown(description="X-Axis:")
y_dropdown = widgets.Dropdown(description="Y-Axis:")
color_dropdown = widgets.Dropdown(description="Color by:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style='success')
output_area = widgets.Output()

# --- 5. Update columns when user picks a table or plot type ---
def update_columns(change):
    with output_area:
        pass # Keep UI stable
    
    table = table_dropdown.value
    ptype = plot_type.value

    if not table:
        return

    # Load sample data to find column types
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 50', conn)
    conn.close()

    # Get different types of columns
    all_cols = df.columns.tolist()
    numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    color_options = ['None'] + categorical_cols # Add 'None' option

    if ptype == "Scatter":
        x_dropdown.options = numeric_cols
        y_dropdown.options = numeric_cols
        color_dropdown.options = color_options
        color_dropdown.layout.display = 'flex' 
        best_fit.layout.display = 'flex'
    else:
        x_dropdown.options = all_cols
        y_dropdown.options = all_cols
        color_dropdown.layout.display = 'none'
        best_fit.layout.display = 'none'

    # Set default values to prevent errors
    if x_dropdown.options:
        x_dropdown.value = x_dropdown.options[0]
    if y_dropdown.options and len(y_dropdown.options) > 1:
        y_dropdown.value = y_dropdown.options[1]
    elif y_dropdown.options:
        y_dropdown.value = y_dropdown.options[0]
    if color_dropdown.options:
        color_dropdown.value = 'None'


# --- 6. Plotting Helper Functions (Cleaner) ---

def plot_scatter_on_axis(ax, df, x_col, y_col, color_col, add_best_fit):
    """Creates a scatter plot, now with optional coloring and integer axis ticks."""
    
    df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
    df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
    df = df.dropna(subset=[x_col, y_col])

    if df.empty:
        ax.text(0.5, 0.5, "No numeric data to plot.", ha='center', va='center')
        return

    # --- Color Logic ---
    if color_col and color_col != 'None' and color_col in df.columns:
        df = df.dropna(subset=[color_col]) 
        categories = df[color_col].unique()
        
        for category in categories:
            df_group = df[df[color_col] == category]
            ax.scatter(df_group[x_col], df_group[y_col], alpha=0.7, label=str(category))
        if len(categories) > 0:
            ax.legend(title=color_col)
    else:
        ax.scatter(df[x_col], df[y_col], color='blue', alpha=0.7)

    ax.set_ylabel(y_col)

    # --- Best-Fit Line Logic ---
    if add_best_fit and len(df) > 1:
        m, b = np.polyfit(df[x_col], df[y_col], 1)
        ax.plot(df[x_col], m * df[x_col] + b, color='red', linewidth=2)
        
    # --- Set Integer Ticks for 'age' ---
    if x_col == 'age':
        ax.xaxis.set_major_locator(mticker.MultipleLocator(1))
    if y_col == 'age':
        ax.yaxis.set_major_locator(mticker.MultipleLocator(1))


def plot_bar_on_axis(ax, df, x_col, y_col):
    """Creates a bar plot. 'age' data should be rounded before calling this."""
    df = df.dropna(subset=[x_col, y_col])
    
    if df.empty:
        ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center')
        return

    is_x_categorical = df[x_col].dtype == 'object'
    is_y_categorical = df[y_col].dtype == 'object'
    is_x_numeric_like = pd.api.types.is_numeric_dtype(df[x_col])
    is_y_numeric_like = pd.api.types.is_numeric_dtype(df[y_col])

    if (is_x_categorical or is_y_categorical) or (x_col == 'age' and is_y_categorical):
        # This is the "color for gender vs age" plot
        grouped = pd.crosstab(df[x_col], df[y_col])
        grouped.plot(kind='bar', ax=ax, width=0.8, alpha=0.8)
        ax.set_ylabel("Count")
        ax.legend(title=y_col, loc='upper right')
    else:
        df.groupby(x_col)[y_col].count().plot(kind='bar', ax=ax, color='skyblue')
        ax.set_ylabel("Count")
    
    # Tick rotation is now handled in the main generate_plot function


# --- 7. When user clicks "Generate Plot" (Simple Version) ---
def generate_plot(b):
    with output_area:
        clear_output(wait=True)
        
        # --- THIS IS THE FIX ---
        # Close all old, stray plots to prevent the black bar
        plt.close('all') 
        # ---------------------

        # --- 1. Get User Selections ---
        table = table_dropdown.value
        x_col = x_dropdown.value
        y_col = y_dropdown.value
        ptype = plot_type.value
        color_col = color_dropdown.value
        
        if not table or not x_col or not y_col:
            print("⚠️ Please select a table and valid X/Y columns.")
            return

        # --- 2. Load Full Data ---
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', conn)
        conn.close()

        # --- 3. Apply Age Filter (if 'age' column exists) ---
        if 'age' in df.columns:
            df['age'] = pd.to_numeric(df['age'], errors='coerce')
            df = df.dropna(subset=['age'])
            df = df[(df['age'] >= 18) & (df['age'] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        if df.empty:
            print("No data available after filtering.")
            return

        # --- 4. Create Plot ---
        fig, ax = plt.subplots(figsize=(7, 4))

        if ptype == "Scatter":
            plot_scatter_on_axis(ax, df, x_col, y_col, color_col, best_fit.value)
        else:
            if 'age' in df.columns:
                df['age'] = df['age'].round().astype(int)
            plot_bar_on_axis(ax, df, x_col, y_col)

        # --- 5. Final Plot Formatting ---
        ax.xaxis.tick_bottom()
        ax.xaxis.set_label_position('bottom')
        ax.yaxis.tick_left()
        ax.yaxis.set_label_position('left')
        
        ax.set_xlabel(x_col)
        ax.set_title(f"{y_col} vs. {x_col} ({ptype})")
        ax.grid(True, linestyle='--', alpha=0.5)

        if ptype == "Bar":
            plt.xticks(rotation=45, ha='right') 
        
        plt.tight_layout()

        # --- 6. Show Image in Output ---
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        display(widgets.Image(value=buf.read(), format='png'))
        
        # We still close the *current* fig, just to be tidy
        plt.close(fig)

# --- 8. Link Widgets to Events ---
table_dropdown.observe(update_columns, names='value')
plot_type.observe(update_columns, names='value')
generate_button.on_click(generate_plot)

# --- 9. Display Interface ---
controls = widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    # Swapped dropdowns as requested
    widgets.HBox([y_dropdown, x_dropdown, color_dropdown]),
    widgets.HBox([generate_button, best_fit])
])

ui = widgets.VBox([controls, output_area])

# Trigger the update once to populate the column dropdowns
update_columns(None)

display(ui)

VBox(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', '…

In [8]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker # For integer axis ticks
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# --- 2. Database Setup ---
DB_NAME = 'Dataset.db'

# --- 3. Get All Tables from Database ---
try:
    conn = sqlite3.connect(DB_NAME)
    tables = pd.read_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
        conn
    )
    table_list = tables['name'].tolist()
    conn.close()
except Exception as e:
    print("Error loading database:", e)
    table_list = []

# --- 4. Create Widgets ---
table_dropdown = widgets.Dropdown(options=table_list, description="Table:")
plot_type = widgets.Dropdown(options=["Scatter", "Bar"], description="Plot Type:")
x_dropdown = widgets.Dropdown(description="X-Axis:")
y_dropdown = widgets.Dropdown(description="Y-Axis:")
color_dropdown = widgets.Dropdown(description="Color by:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style='success')
output_area = widgets.Output()

# --- 5. Update columns when user picks a table or plot type ---
def update_columns(change):
    with output_area:
        pass # Keep UI stable
    
    table = table_dropdown.value
    ptype = plot_type.value

    if not table:
        return

    # Load sample data to find column types
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 50', conn)
    conn.close()

    # Get different types of columns
    all_cols = df.columns.tolist()
    numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    color_options = ['None'] + categorical_cols # Add 'None' option

    if ptype == "Scatter":
        x_dropdown.options = numeric_cols
        y_dropdown.options = numeric_cols
        color_dropdown.options = color_options
        color_dropdown.layout.display = 'flex' 
        best_fit.layout.display = 'flex'
    else:
        x_dropdown.options = all_cols
        y_dropdown.options = all_cols
        color_dropdown.layout.display = 'none'
        best_fit.layout.display = 'none'

    # Set default values to prevent errors
    if x_dropdown.options:
        x_dropdown.value = x_dropdown.options[0]
    if y_dropdown.options and len(y_dropdown.options) > 1:
        y_dropdown.value = y_dropdown.options[1]
    elif y_dropdown.options:
        y_dropdown.value = y_dropdown.options[0]
    if color_dropdown.options:
        color_dropdown.value = 'None'


# --- 6. Plotting Helper Functions (Cleaner) ---

def plot_scatter_on_axis(ax, df, x_col, y_col, color_col, add_best_fit):
    """Creates a scatter plot, now with optional coloring and integer axis ticks."""
    
    df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
    df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
    df = df.dropna(subset=[x_col, y_col])

    if df.empty:
        ax.text(0.5, 0.5, "No numeric data to plot.", ha='center', va='center')
        return

    # --- Color Logic ---
    if color_col and color_col != 'None' and color_col in df.columns:
        df = df.dropna(subset=[color_col]) 
        categories = df[color_col].unique()
        
        for category in categories:
            df_group = df[df[color_col] == category]
            ax.scatter(df_group[x_col], df_group[y_col], alpha=0.7, label=str(category))
        if len(categories) > 0:
            ax.legend(title=color_col)
    else:
        ax.scatter(df[x_col], df[y_col], color='blue', alpha=0.7)

    ax.set_ylabel(y_col)

    # --- Best-Fit Line Logic ---
    if add_best_fit and len(df) > 1:
        m, b = np.polyfit(df[x_col], df[y_col], 1)
        ax.plot(df[x_col], m * df[x_col] + b, color='red', linewidth=2)
        
    # --- Set Integer Ticks for 'age' ---
    if x_col == 'age':
        ax.xaxis.set_major_locator(mticker.MultipleLocator(1))
    if y_col == 'age':
        ax.yaxis.set_major_locator(mticker.MultipleLocator(1))


def plot_bar_on_axis(ax, df, x_col, y_col):
    """Creates a bar plot. 'age' data should be rounded before calling this."""
    df = df.dropna(subset=[x_col, y_col])
    
    if df.empty:
        ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center')
        return

    is_x_categorical = df[x_col].dtype == 'object'
    is_y_categorical = df[y_col].dtype == 'object'

    if (is_x_categorical or is_y_categorical) or (x_col == 'age' and is_y_categorical):
        # This is the "color for gender vs age" plot
        grouped = pd.crosstab(df[x_col], df[y_col])
        grouped.plot(kind='bar', ax=ax, width=0.8, alpha=0.8) # Keep this for grouped
        ax.set_ylabel("Count")
        ax.legend(title=y_col, loc='upper right')
        # Apply rotation settings directly to the axis
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
    else:
        # --- THIS IS THE FIX ---
        # Replaced df.groupby(...).plot() with ax.bar()
        grouped_data = df.groupby(x_col)[y_col].count()
        # Ensure index is string for categorical plotting
        ax.bar(grouped_data.index.astype(str), grouped_data.values, color='skyblue')
        ax.set_ylabel("Count")
        # Apply rotation settings directly to the axis
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
        # --- END FIX ---


# --- 7. When user clicks "Generate Plot" (Simple Version) ---
def generate_plot(b):
    with output_area:
        clear_output(wait=True)
        
        # --- Close all old, stray plots to prevent the black bar ---
        plt.close('all') 

        # --- 1. Get User Selections ---
        table = table_dropdown.value
        x_col = x_dropdown.value
        y_col = y_dropdown.value
        ptype = plot_type.value
        color_col = color_dropdown.value
        
        if not table or not x_col or not y_col:
            print("⚠️ Please select a table and valid X/Y columns.")
            return

        # --- 2. Load Full Data ---
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', conn)
        conn.close()

        # --- 3. Apply Age Filter (if 'age' column exists) ---
        if 'age' in df.columns:
            df['age'] = pd.to_numeric(df['age'], errors='coerce')
            df = df.dropna(subset=['age'])
            df = df[(df['age'] >= 18) & (df['age'] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        if df.empty:
            print("No data available after filtering.")
            return

        # --- 4. Create Plot ---
        fig, ax = plt.subplots(figsize=(7, 4))

        if ptype == "Scatter":
            plot_scatter_on_axis(ax, df, x_col, y_col, color_col, best_fit.value)
        else:
            if 'age' in df.columns:
                df['age'] = df['age'].round().astype(int)
            plot_bar_on_axis(ax, df, x_col, y_col)

        # --- 5. Final Plot Formatting ---
        ax.xaxis.tick_bottom()
        ax.xaxis.set_label_position('bottom')
        ax.yaxis.tick_left()
        ax.yaxis.set_label_position('left')
        
        ax.set_xlabel(x_col)
        ax.set_title(f"{y_col} vs. {x_col} ({ptype})")
        ax.grid(True, linestyle='--', alpha=0.5)

        # --- The rotation logic was moved to plot_bar_on_axis ---
        
        plt.tight_layout()

        # --- 6. Show Image in Output ---
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        display(widgets.Image(value=buf.read(), format='png'))
        
        # We still close the *current* fig, just to be tidy
        plt.close(fig)

# --- 8. Link Widgets to Events ---
table_dropdown.observe(update_columns, names='value')
plot_type.observe(update_columns, names='value')
generate_button.on_click(generate_plot)

# --- 9. Display Interface ---
controls = widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    # Swapped dropdowns as requested
    widgets.HBox([y_dropdown, x_dropdown, color_dropdown]),
    widgets.HBox([generate_button, best_fit])
])

ui = widgets.VBox([controls, output_area])

# Trigger the update once to populate the column dropdowns
update_columns(None)

display(ui)

VBox(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', '…

In [9]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker # For integer axis ticks
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# --- 2. Database Setup ---
DB_NAME = 'Dataset.db'

# --- 3. Get All Tables from Database ---
try:
    conn = sqlite3.connect(DB_NAME)
    tables = pd.read_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
        conn
    )
    table_list = tables['name'].tolist()
    conn.close()
except Exception as e:
    print("Error loading database:", e)
    table_list = []

# --- 4. Create Widgets ---
table_dropdown = widgets.Dropdown(options=table_list, description="Table:")
plot_type = widgets.Dropdown(options=["Scatter", "Bar"], description="Plot Type:")
x_dropdown = widgets.Dropdown(description="X-Axis:")
y_dropdown = widgets.Dropdown(description="Y-Axis:")
color_dropdown = widgets.Dropdown(description="Color by:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style='success')
output_area = widgets.Output()

# --- 5. Update columns when user picks a table or plot type ---
def update_columns(change):
    with output_area:
        pass # Keep UI stable
    
    table = table_dropdown.value
    ptype = plot_type.value

    if not table:
        return

    # Load sample data to find column types
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 50', conn)
    conn.close()

    # Get different types of columns
    all_cols = df.columns.tolist()
    numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    color_options = ['None'] + categorical_cols # Add 'None' option

    if ptype == "Scatter":
        x_dropdown.options = numeric_cols
        y_dropdown.options = numeric_cols
        color_dropdown.options = color_options
        color_dropdown.layout.display = 'flex' 
        best_fit.layout.display = 'flex'
    else:
        x_dropdown.options = all_cols
        y_dropdown.options = all_cols
        color_dropdown.layout.display = 'none'
        best_fit.layout.display = 'none'

    # Set default values to prevent errors
    if x_dropdown.options:
        x_dropdown.value = x_dropdown.options[0]
    if y_dropdown.options and len(y_dropdown.options) > 1:
        y_dropdown.value = y_dropdown.options[1]
    elif y_dropdown.options:
        y_dropdown.value = y_dropdown.options[0]
    if color_dropdown.options:
        color_dropdown.value = 'None'


# --- 6. Plotting Helper Functions (Cleaner) ---

def plot_scatter_on_axis(ax, df, x_col, y_col, color_col, add_best_fit):
    """Creates a scatter plot, now with optional coloring and integer axis ticks."""
    
    df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
    df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
    df = df.dropna(subset=[x_col, y_col])

    if df.empty:
        ax.text(0.5, 0.5, "No numeric data to plot.", ha='center', va='center')
        return

    # --- Color Logic ---
    if color_col and color_col != 'None' and color_col in df.columns:
        df = df.dropna(subset=[color_col]) 
        categories = df[color_col].unique()
        
        for category in categories:
            df_group = df[df[color_col] == category]
            ax.scatter(df_group[x_col], df_group[y_col], alpha=0.7, label=str(category))
        if len(categories) > 0:
            ax.legend(title=color_col)
    else:
        ax.scatter(df[x_col], df[y_col], color='blue', alpha=0.7)

    ax.set_ylabel(y_col)

    # --- Best-Fit Line Logic ---
    if add_best_fit and len(df) > 1:
        m, b = np.polyfit(df[x_col], df[y_col], 1)
        ax.plot(df[x_col], m * df[x_col] + b, color='red', linewidth=2)
        
    # --- Set Integer Ticks for 'age' ---
    if x_col == 'age':
        ax.xaxis.set_major_locator(mticker.MultipleLocator(1))
    if y_col == 'age':
        ax.yaxis.set_major_locator(mticker.MultipleLocator(1))


# --- THIS ENTIRE FUNCTION IS REWRITTEN TO FIX THE BLACK BAR ---
def plot_bar_on_axis(ax, df, x_col, y_col):
    """Creates a bar plot. 'age' data should be rounded before calling this."""
    df = df.dropna(subset=[x_col, y_col])
    
    if df.empty:
        ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center')
        return

    is_x_categorical = df[x_col].dtype == 'object'
    is_y_categorical = df[y_col].dtype == 'object'

    if (is_x_categorical or is_y_categorical) or (x_col == 'age' and is_y_categorical):
        # --- NEW: Re-write of "color for gender vs age" plot ---
        # This avoids using pandas.plot() which causes the black bar
        grouped = pd.crosstab(df[x_col], df[y_col])
        
        n_groups = len(grouped)
        n_categories = len(grouped.columns)
        
        # Calculate bar width and positions
        total_width_per_group = 0.8
        bar_width = total_width_per_group / n_categories
        x_indices = np.arange(n_groups) # The center position for each group
        
        # Loop for each category (e.g., 'Male', 'Female') and plot its bars
        for i, category in enumerate(grouped.columns):
            # Calculate the position for this specific bar
            offset = (i - (n_categories - 1) / 2) * bar_width
            bar_positions = x_indices + offset
            
            counts = grouped[category]
            ax.bar(bar_positions, counts, width=bar_width, label=category, alpha=0.8)

        ax.set_ylabel("Count")
        ax.legend(title=y_col, loc='upper right')
        
        # Set the x-ticks to be in the center of the groups
        ax.set_xticks(x_indices)
        ax.set_xticklabels(grouped.index, rotation=45, ha='right')
        # --- END NEW ---

    else:
        # This part is already fixed (uses ax.bar)
        grouped_data = df.groupby(x_col)[y_col].count()
        # Ensure index is string for categorical plotting
        ax.bar(grouped_data.index.astype(str), grouped_data.values, color='skyblue')
        ax.set_ylabel("Count")
        # Apply rotation settings directly to the axis
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')


# --- 7. When user clicks "Generate Plot" (Simple Version) ---
def generate_plot(b):
    with output_area:
        clear_output(wait=True)
        
        # --- Close all old, stray plots to prevent the black bar ---
        plt.close('all') 

        # --- 1. Get User Selections ---
        table = table_dropdown.value
        x_col = x_dropdown.value
        y_col = y_dropdown.value
        ptype = plot_type.value
        color_col = color_dropdown.value
        
        if not table or not x_col or not y_col:
            print("⚠️ Please select a table and valid X/Y columns.")
            return

        # --- 2. Load Full Data ---
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', conn)
        conn.close()

        # --- 3. Apply Age Filter (if 'age' column exists) ---
        if 'age' in df.columns:
            df['age'] = pd.to_numeric(df['age'], errors='coerce')
            df = df.dropna(subset=['age'])
            df = df[(df['age'] >= 18) & (df['age'] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        if df.empty:
            print("No data available after filtering.")
            return

        # --- 4. Create Plot ---
        fig, ax = plt.subplots(figsize=(7, 4))

        if ptype == "Scatter":
            plot_scatter_on_axis(ax, df, x_col, y_col, color_col, best_fit.value)
        else:
            if 'age' in df.columns:
                df['age'] = df['age'].round().astype(int)
            plot_bar_on_axis(ax, df, x_col, y_col)

        # --- 5. Final Plot Formatting ---
        ax.xaxis.tick_bottom()
        ax.xaxis.set_label_position('bottom')
        ax.yaxis.tick_left()
        ax.yaxis.set_label_position('left')
        
        ax.set_xlabel(x_col)
        ax.set_title(f"{y_col} vs. {x_col} ({ptype})")
        ax.grid(True, linestyle='--', alpha=0.5)

        # --- The rotation logic was moved to plot_bar_on_axis ---
        
        plt.tight_layout()

        # --- 6. Show Image in Output ---
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        display(widgets.Image(value=buf.read(), format='png'))
        
        # We still close the *current* fig, just to be tidy
        plt.close(fig)

# --- 8. Link Widgets to Events ---
table_dropdown.observe(update_columns, names='value')
plot_type.observe(update_columns, names='value')
generate_button.on_click(generate_plot)

# --- 9. Display Interface ---
controls = widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    # Swapped dropdowns as requested
    widgets.HBox([y_dropdown, x_dropdown, color_dropdown]),
    widgets.HBox([generate_button, best_fit])
])

ui = widgets.VBox([controls, output_area])

# Trigger the update once to populate the column dropdowns
update_columns(None)

display(ui)

VBox(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', '…

In [10]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker # For integer axis ticks
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# --- 2. Database Setup ---
DB_NAME = 'Dataset.db'

# --- 3. Get All Tables from Database ---
try:
    conn = sqlite3.connect(DB_NAME)
    tables = pd.read_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
        conn
    )
    table_list = tables['name'].tolist()
    conn.close()
except Exception as e:
    print("Error loading database:", e)
    table_list = []

# --- 4. Create Widgets ---
table_dropdown = widgets.Dropdown(options=table_list, description="Table:")
plot_type = widgets.Dropdown(options=["Scatter", "Bar"], description="Plot Type:")
x_dropdown = widgets.Dropdown(description="X-Axis:")
y_dropdown = widgets.Dropdown(description="Y-Axis:")
color_dropdown = widgets.Dropdown(description="Color by:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style='success')
output_area = widgets.Output()

# --- 5. Update columns when user picks a table or plot type ---
def update_columns(change):
    with output_area:
        pass # Keep UI stable
    
    table = table_dropdown.value
    ptype = plot_type.value

    if not table:
        return

    # Load sample data to find column types
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 50', conn)
    conn.close()

    # Get different types of columns
    all_cols = df.columns.tolist()
    numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    color_options = ['None'] + categorical_cols # Add 'None' option

    if ptype == "Scatter":
        x_dropdown.options = numeric_cols
        y_dropdown.options = numeric_cols
        color_dropdown.options = color_options
        color_dropdown.layout.display = 'flex' 
        best_fit.layout.display = 'flex'
    else:
        x_dropdown.options = all_cols
        y_dropdown.options = all_cols
        color_dropdown.layout.display = 'none'
        best_fit.layout.display = 'none'

    # Set default values to prevent errors
    if x_dropdown.options:
        x_dropdown.value = x_dropdown.options[0]
    if y_dropdown.options and len(y_dropdown.options) > 1:
        y_dropdown.value = y_dropdown.options[1]
    elif y_dropdown.options:
        y_dropdown.value = y_dropdown.options[0]
    if color_dropdown.options:
        color_dropdown.value = 'None'


# --- 6. Plotting Helper Functions (Cleaner) ---

def plot_scatter_on_axis(ax, df, x_col, y_col, color_col, add_best_fit):
    """Creates a scatter plot, now with optional coloring and integer axis ticks."""
    
    df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
    df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
    df = df.dropna(subset=[x_col, y_col])

    if df.empty:
        ax.text(0.5, 0.5, "No numeric data to plot.", ha='center', va='center')
        return

    # --- Color Logic ---
    if color_col and color_col != 'None' and color_col in df.columns:
        df = df.dropna(subset=[color_col]) 
        categories = df[color_col].unique()
        
        for category in categories:
            df_group = df[df[color_col] == category]
            ax.scatter(df_group[x_col], df_group[y_col], alpha=0.7, label=str(category))
        if len(categories) > 0:
            ax.legend(title=color_col)
    else:
        ax.scatter(df[x_col], df[y_col], color='blue', alpha=0.7)

    ax.set_ylabel(y_col)

    # --- Best-Fit Line Logic ---
    if add_best_fit and len(df) > 1:
        m, b = np.polyfit(df[x_col], df[y_col], 1)
        ax.plot(df[x_col], m * df[x_col] + b, color='red', linewidth=2)
        
    # --- Set Integer Ticks for 'age' ---
    if x_col == 'age':
        ax.xaxis.set_major_locator(mticker.MultipleLocator(1))
    if y_col == 'age':
        ax.yaxis.set_major_locator(mticker.MultipleLocator(1))


def plot_bar_on_axis(ax, df, x_col, y_col):
    """Creates a bar plot. 'age' data should be rounded before calling this."""
    df = df.dropna(subset=[x_col, y_col])
    
    if df.empty:
        ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center')
        return

    is_x_categorical = df[x_col].dtype == 'object'
    is_y_categorical = df[y_col].dtype == 'object'

    # --- THIS IS THE FIX ---
    # Only create a grouped chart if BOTH are categorical (e.g., gender vs. major)
    if is_x_categorical and is_y_categorical:
    # --- END FIX ---
        # --- "color for gender vs major" plot ---
        grouped = pd.crosstab(df[x_col], df[y_col])
        
        n_groups = len(grouped)
        n_categories = len(grouped.columns)
        
        # Calculate bar width and positions
        total_width_per_group = 0.8
        bar_width = total_width_per_group / n_categories
        x_indices = np.arange(n_groups) # The center position for each group
        
        # Loop for each category (e.g., 'Male', 'Female') and plot its bars
        for i, category in enumerate(grouped.columns):
            # Calculate the position for this specific bar
            offset = (i - (n_categories - 1) / 2) * bar_width
            bar_positions = x_indices + offset
            
            counts = grouped[category]
            ax.bar(bar_positions, counts, width=bar_width, label=category, alpha=0.8)

        ax.set_ylabel("Count")
        ax.legend(title=y_col, loc='upper right')
        
        # Set the x-ticks to be in the center of the groups
        ax.set_xticks(x_indices)
        ax.set_xticklabels(grouped.index, rotation=45, ha='right')
        # --- END NEW ---

    else:
        # --- This is the simple bar chart (e.g., "gender vs. age") ---
        # It will group by the X-column and count the Y-column entries.
        grouped_data = df.groupby(x_col)[y_col].count()
        # Ensure index is string for categorical plotting
        ax.bar(grouped_data.index.astype(str), grouped_data.values, color='skyblue')
        ax.set_ylabel("Count")

In [11]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker # For integer axis ticks
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# --- 2. Database Setup ---
DB_NAME = 'Dataset.db'

# --- 3. Get All Tables from Database ---
try:
    conn = sqlite3.connect(DB_NAME)
    tables = pd.read_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
        conn
    )
    table_list = tables['name'].tolist()
    conn.close()
except Exception as e:
    print("Error loading database:", e)
    table_list = []

# --- 4. Create Widgets ---
table_dropdown = widgets.Dropdown(options=table_list, description="Table:")
plot_type = widgets.Dropdown(options=["Scatter", "Bar"], description="Plot Type:")
x_dropdown = widgets.Dropdown(description="X-Axis:")
y_dropdown = widgets.Dropdown(description="Y-Axis:")
color_dropdown = widgets.Dropdown(description="Color by:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style='success')
output_area = widgets.Output()

# --- 5. Update columns when user picks a table or plot type ---
def update_columns(change):
    with output_area:
        pass # Keep UI stable
    
    table = table_dropdown.value
    ptype = plot_type.value

    if not table:
        return

    # Load sample data to find column types
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 50', conn)
    conn.close()

    # Get different types of columns
    all_cols = df.columns.tolist()
    numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    color_options = ['None'] + categorical_cols # Add 'None' option

    if ptype == "Scatter":
        x_dropdown.options = numeric_cols
        y_dropdown.options = numeric_cols
        color_dropdown.options = color_options
        color_dropdown.layout.display = 'flex' 
        best_fit.layout.display = 'flex'
    else:
        x_dropdown.options = all_cols
        y_dropdown.options = all_cols
        color_dropdown.layout.display = 'none'
        best_fit.layout.display = 'none'

    # Set default values to prevent errors
    if x_dropdown.options:
        x_dropdown.value = x_dropdown.options[0]
    if y_dropdown.options and len(y_dropdown.options) > 1:
        y_dropdown.value = y_dropdown.options[1]
    elif y_dropdown.options:
        y_dropdown.value = y_dropdown.options[0]
    if color_dropdown.options:
        color_dropdown.value = 'None'


# --- 6. Plotting Helper Functions (Cleaner) ---

def plot_scatter_on_axis(ax, df, x_col, y_col, color_col, add_best_fit):
    """Creates a scatter plot, now with optional coloring and integer axis ticks."""
    
    df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
    df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
    df = df.dropna(subset=[x_col, y_col])

    if df.empty:
        ax.text(0.5, 0.5, "No numeric data to plot.", ha='center', va='center')
        return

    # --- Color Logic ---
    if color_col and color_col != 'None' and color_col in df.columns:
        df = df.dropna(subset=[color_col]) 
        categories = df[color_col].unique()
        
        for category in categories:
            df_group = df[df[color_col] == category]
            ax.scatter(df_group[x_col], df_group[y_col], alpha=0.7, label=str(category))
        if len(categories) > 0:
            ax.legend(title=color_col)
    else:
        ax.scatter(df[x_col], df[y_col], color='blue', alpha=0.7)

    ax.set_ylabel(y_col)

    # --- Best-Fit Line Logic ---
    if add_best_fit and len(df) > 1:
        m, b = np.polyfit(df[x_col], df[y_col], 1)
        ax.plot(df[x_col], m * df[x_col] + b, color='red', linewidth=2)
        
    # --- Set Integer Ticks for 'age' ---
    if x_col == 'age':
        ax.xaxis.set_major_locator(mticker.MultipleLocator(1))
    if y_col == 'age':
        ax.yaxis.set_major_locator(mticker.MultipleLocator(1))


def plot_bar_on_axis(ax, df, x_col, y_col):
    """Creates a bar plot. 'age' data should be rounded before calling this."""
    df = df.dropna(subset=[x_col, y_col])
    
    if df.empty:
        ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center')
        return

    is_x_categorical = df[x_col].dtype == 'object'
    is_y_categorical = df[y_col].dtype == 'object'

    # --- THIS IS THE FIX ---
    # Only create a grouped chart if BOTH are categorical (e.g., gender vs. major)
    if is_x_categorical and is_y_categorical:
    # --- END FIX ---
        # --- "color for gender vs major" plot ---
        grouped = pd.crosstab(df[x_col], df[y_col])
        
        n_groups = len(grouped)
        n_categories = len(grouped.columns)
        
        # Calculate bar width and positions
        total_width_per_group = 0.8
        bar_width = total_width_per_group / n_categories
        x_indices = np.arange(n_groups) # The center position for each group
        
        # Loop for each category (e.g., 'Male', 'Female') and plot its bars
        for i, category in enumerate(grouped.columns):
            # Calculate the position for this specific bar
            offset = (i - (n_categories - 1) / 2) * bar_width
            bar_positions = x_indices + offset
            
            counts = grouped[category]
            ax.bar(bar_positions, counts, width=bar_width, label=category, alpha=0.8)

        ax.set_ylabel("Count")
        ax.legend(title=y_col, loc='upper right')
        
        # Set the x-ticks to be in the center of the groups
        ax.set_xticks(x_indices)
        ax.set_xticklabels(grouped.index, rotation=45, ha='right')
        # --- END NEW ---

    else:
        # --- This is the simple bar chart (e.g., "gender vs. age") ---
        # It will group by the X-column and count the Y-column entries.
        grouped_data = df.groupby(x_col)[y_col].count()
        # Ensure index is string for categorical plotting
        ax.bar(grouped_data.index.astype(str), grouped_data.values, color='skyblue')
        ax.set_ylabel("Count")
        # Apply rotation settings directly to the axis
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')


# --- 7. When user clicks "Generate Plot" (Simple Version) ---
def generate_plot(b):
    with output_area:
        clear_output(wait=True)
        
        # --- Close all old, stray plots to prevent the black bar ---
        plt.close('all') 

        # --- 1. Get User Selections ---
        table = table_dropdown.value
        x_col = x_dropdown.value
        y_col = y_dropdown.value
        ptype = plot_type.value
        color_col = color_dropdown.value
        
        if not table or not x_col or not y_col:
            print("⚠️ Please select a table and valid X/Y columns.")
            return

        # --- 2. Load Full Data ---
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', conn)
        conn.close()

        # --- 3. Apply Age Filter (if 'age' column exists) ---
        if 'age' in df.columns:
            df['age'] = pd.to_numeric(df['age'], errors='coerce')
            df = df.dropna(subset=['age'])
            df = df[(df['age'] >= 18) & (df['age'] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        if df.empty:
            print("No data available after filtering.")
            return

        # --- 4. Create Plot ---
        fig, ax = plt.subplots(figsize=(7, 4))

        if ptype == "Scatter":
            plot_scatter_on_axis(ax, df, x_col, y_col, color_col, best_fit.value)
        else:
            if 'age' in df.columns:
                df['age'] = df['age'].round().astype(int)
            plot_bar_on_axis(ax, df, x_col, y_col)

        # --- 5. Final Plot Formatting ---
        ax.xaxis.tick_bottom()
        ax.xaxis.set_label_position('bottom')
        ax.yaxis.tick_left()
        ax.yaxis.set_label_position('left')
        
        ax.set_xlabel(x_col)
        ax.set_title(f"{y_col} vs. {x_col} ({ptype})")
        ax.grid(True, linestyle='--', alpha=0.5)
        
        plt.tight_layout()

        # --- 6. Show Image in Output ---
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        display(widgets.Image(value=buf.read(), format='png'))
        
        # We still close the *current* fig, just to be tidy
        plt.close(fig)

# --- 8. Link Widgets to Events ---
table_dropdown.observe(update_columns, names='value')
plot_type.observe(update_columns, names='value')
generate_button.on_click(generate_plot)

# --- 9. Display Interface ---
controls = widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    # Swapped dropdowns as requested
    widgets.HBox([y_dropdown, x_dropdown, color_dropdown]),
    widgets.HBox([generate_button, best_fit])
])

ui = widgets.VBox([controls, output_area])

# Trigger the update once to populate the column dropdowns
update_columns(None)

display(ui)

VBox(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', '…

In [12]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker # For integer axis ticks
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# --- 2. Database Setup ---
DB_NAME = 'Dataset.db'

# --- 3. Get All Tables from Database ---
try:
    conn = sqlite3.connect(DB_NAME)
    tables = pd.read_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
        conn
    )
    table_list = tables['name'].tolist()
    conn.close()
except Exception as e:
    print("Error loading database:", e)
    table_list = []

# --- 4. Create Widgets ---
table_dropdown = widgets.Dropdown(options=table_list, description="Table:")
plot_type = widgets.Dropdown(options=["Scatter", "Bar"], description="Plot Type:")
x_dropdown = widgets.Dropdown(description="X-Axis:")
y_dropdown = widgets.Dropdown(description="Y-Axis:")
color_dropdown = widgets.Dropdown(description="Color by:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style='success')
output_area = widgets.Output()

# --- 5. Update columns when user picks a table or plot type ---
def update_columns(change):
    with output_area:
        pass # Keep UI stable
    
    table = table_dropdown.value
    ptype = plot_type.value

    if not table:
        return

    # Load sample data to find column types
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 50', conn)
    conn.close()

    # Get different types of columns
    all_cols = df.columns.tolist()
    numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    color_options = ['None'] + categorical_cols # Add 'None' option

    if ptype == "Scatter":
        x_dropdown.options = numeric_cols
        y_dropdown.options = numeric_cols
        color_dropdown.options = color_options
        color_dropdown.layout.display = 'flex' 
        best_fit.layout.display = 'flex'
    else:
        x_dropdown.options = all_cols
        y_dropdown.options = all_cols
        color_dropdown.layout.display = 'none'
        best_fit.layout.display = 'none'

    # Set default values to prevent errors
    if x_dropdown.options:
        x_dropdown.value = x_dropdown.options[0]
    if y_dropdown.options and len(y_dropdown.options) > 1:
        y_dropdown.value = y_dropdown.options[1]
    elif y_dropdown.options:
        y_dropdown.value = y_dropdown.options[0]
    if color_dropdown.options:
        color_dropdown.value = 'None'


# --- 6. Plotting Helper Functions (Cleaner) ---

def plot_scatter_on_axis(ax, df, x_col, y_col, color_col, add_best_fit):
    """Creates a scatter plot, now with optional coloring and integer axis ticks."""
    
    df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
    df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
    df = df.dropna(subset=[x_col, y_col])

    if df.empty:
        ax.text(0.5, 0.5, "No numeric data to plot.", ha='center', va='center')
        return

    # --- Color Logic ---
    if color_col and color_col != 'None' and color_col in df.columns:
        df = df.dropna(subset=[color_col]) 
        categories = df[color_col].unique()
        
        for category in categories:
            df_group = df[df[color_col] == category]
            ax.scatter(df_group[x_col], df_group[y_col], alpha=0.7, label=str(category))
        if len(categories) > 0:
            ax.legend(title=color_col)
    else:
        ax.scatter(df[x_col], df[y_col], color='blue', alpha=0.7)

    ax.set_ylabel(y_col)

    # --- Best-Fit Line Logic ---
    if add_best_fit and len(df) > 1:
        m, b = np.polyfit(df[x_col], df[y_col], 1)
        ax.plot(df[x_col], m * df[x_col] + b, color='red', linewidth=2)
        
    # --- Set Integer Ticks for 'age' ---
    if x_col == 'age':
        ax.xaxis.set_major_locator(mticker.MultipleLocator(1))
    if y_col == 'age':
        ax.yaxis.set_major_locator(mticker.MultipleLocator(1))


def plot_bar_on_axis(ax, df, x_col, y_col):
    """Creates a bar plot. 'age' data should be rounded before calling this."""
    df = df.dropna(subset=[x_col, y_col])
    
    if df.empty:
        ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center')
        return

    is_x_categorical = df[x_col].dtype == 'object'
    is_y_categorical = df[y_col].dtype == 'object'

    # Only create a grouped chart if BOTH are categorical (e.g., gender vs. major)
    if is_x_categorical and is_y_categorical:
        # --- "color for gender vs major" plot ---
        grouped = pd.crosstab(df[x_col], df[y_col])
        
        n_groups = len(grouped)
        n_categories = len(grouped.columns)
        
        # Calculate bar width and positions
        total_width_per_group = 0.8
        bar_width = total_width_per_group / n_categories
        x_indices = np.arange(n_groups) # The center position for each group
        
        # Loop for each category (e.g., 'Male', 'Female') and plot its bars
        for i, category in enumerate(grouped.columns):
            # Calculate the position for this specific bar
            offset = (i - (n_categories - 1) / 2) * bar_width
            bar_positions = x_indices + offset
            
            counts = grouped[category]
            ax.bar(bar_positions, counts, width=bar_width, label=category, alpha=0.8)

        ax.set_ylabel("Count")
        ax.legend(title=y_col, loc='upper right')
        
        # Set the x-ticks to be in the center of the groups
        ax.set_xticks(x_indices)
        ax.set_xticklabels(grouped.index, rotation=45, ha='right')

    else:
        # --- This is the simple bar chart (e.g., "gender vs. age") ---
        grouped_data = df.groupby(x_col)[y_col].count()
        
        # --- THIS IS THE FIX ---
        # Get labels and positions
        x_labels = grouped_data.index.astype(str)
        x_indices = np.arange(len(x_labels))

        # Plot the bar
        ax.bar(x_indices, grouped_data.values, color='skyblue')
        ax.set_ylabel("Count")
        
        # NOW we set the ticks and labels explicitly
        ax.set_xticks(x_indices) # Tell matplotlib where the ticks are
        ax.set_xticklabels(x_labels, rotation=45, ha='right') # NOW set the labels
        # --- END FIX ---


# --- 7. When user clicks "Generate Plot" (Simple Version) ---
def generate_plot(b):
    with output_area:
        clear_output(wait=True)
        
        # --- Close all old, stray plots to prevent the black bar ---
        plt.close('all') 

        # --- 1. Get User Selections ---
        table = table_dropdown.value
        x_col = x_dropdown.value
        y_col = y_dropdown.value
        ptype = plot_type.value
        color_col = color_dropdown.value
        
        if not table or not x_col or not y_col:
            print("⚠️ Please select a table and valid X/Y columns.")
            return

        # --- 2. Load Full Data ---
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', conn)
        conn.close()

        # --- 3. Apply Age Filter (if 'age' column exists) ---
        if 'age' in df.columns:
            df['age'] = pd.to_numeric(df['age'], errors='coerce')
            df = df.dropna(subset=['age'])
            df = df[(df['age'] >= 18) & (df['age'] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        if df.empty:
            print("No data available after filtering.")
            return

        # --- 4. Create Plot ---
        fig, ax = plt.subplots(figsize=(7, 4))

        if ptype == "Scatter":
            plot_scatter_on_axis(ax, df, x_col, y_col, color_col, best_fit.value)
        else:
            if 'age' in df.columns:
                df['age'] = df['age'].round().astype(int)
            plot_bar_on_axis(ax, df, x_col, y_col)

        # --- 5. Final Plot Formatting ---
        ax.xaxis.tick_bottom()
        ax.xaxis.set_label_position('bottom')
        ax.yaxis.tick_left()
        ax.yaxis.set_label_position('left')
        
        ax.set_xlabel(x_col)
        ax.set_title(f"{y_col} vs. {x_col} ({ptype})")
        ax.grid(True, linestyle='--', alpha=0.5)
        
        plt.tight_layout()

        # --- 6. Show Image in Output ---
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        display(widgets.Image(value=buf.read(), format='png'))
        
        # We still close the *current* fig, just to be tidy
        plt.close(fig)

# --- 8. Link Widgets to Events ---
table_dropdown.observe(update_columns, names='Svalue')
plot_type.observe(update_columns, names='value')
generate_button.on_click(generate_plot)

# --- 9. Display Interface ---
controls = widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    # Swapped dropdowns as requested
    widgets.HBox([y_dropdown, x_dropdown, color_dropdown]),
    widgets.HBox([generate_button, best_fit])
])

ui = widgets.VBox([controls, output_area])

# Trigger the update once to populate the column dropdowns
update_columns(None)

display(ui)

VBox(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', '…

In [13]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker # For integer axis ticks
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# --- 2. Database Setup ---
DB_NAME = 'Dataset.db'

# --- 3. Get All Tables from Database ---
try:
    conn = sqlite3.connect(DB_NAME)
    tables = pd.read_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
        conn
    )
    table_list = tables['name'].tolist()
    conn.close()
except Exception as e:
    print("Error loading database:", e)
    table_list = []

# --- 4. Create Widgets ---
table_dropdown = widgets.Dropdown(options=table_list, description="Table:")
plot_type = widgets.Dropdown(options=["Scatter", "Bar"], description="Plot Type:")
x_dropdown = widgets.Dropdown(description="X-Axis:")
y_dropdown = widgets.Dropdown(description="Y-Axis:")
color_dropdown = widgets.Dropdown(description="Color by:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style='success')
output_area = widgets.Output()

# --- 5. Update columns when user picks a table or plot type ---
def update_columns(change):
    with output_area:
        pass # Keep UI stable
    
    table = table_dropdown.value
    ptype = plot_type.value

    if not table:
        return

    # Load sample data to find column types
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 50', conn)
    conn.close()

    # Get different types of columns
    all_cols = df.columns.tolist()
    numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    color_options = ['None'] + categorical_cols # Add 'None' option

    if ptype == "Scatter":
        x_dropdown.options = numeric_cols
        y_dropdown.options = numeric_cols
        color_dropdown.options = color_options
        color_dropdown.layout.display = 'flex' 
        best_fit.layout.display = 'flex'
    else:
        x_dropdown.options = all_cols
        y_dropdown.options = all_cols
        color_dropdown.layout.display = 'none'
        best_fit.layout.display = 'none'

    # Set default values to prevent errors
    if x_dropdown.options:
        x_dropdown.value = x_dropdown.options[0]
    if y_dropdown.options and len(y_dropdown.options) > 1:
        y_dropdown.value = y_dropdown.options[1]
    elif y_dropdown.options:
        y_dropdown.value = y_dropdown.options[0]
    if color_dropdown.options:
        color_dropdown.value = 'None'


# --- 6. Plotting Helper Functions (Cleaner) ---

def plot_scatter_on_axis(ax, df, x_col, y_col, color_col, add_best_fit):
    """Creates a scatter plot, now with optional coloring and integer axis ticks."""
    
    df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
    df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
    df = df.dropna(subset=[x_col, y_col])

    if df.empty:
        ax.text(0.5, 0.5, "No numeric data to plot.", ha='center', va='center')
        return

    # --- Color Logic ---
    if color_col and color_col != 'None' and color_col in df.columns:
        df = df.dropna(subset=[color_col]) 
        categories = df[color_col].unique()
        
        for category in categories:
            df_group = df[df[color_col] == category]
            ax.scatter(df_group[x_col], df_group[y_col], alpha=0.7, label=str(category))
        if len(categories) > 0:
            ax.legend(title=color_col)
    else:
        ax.scatter(df[x_col], df[y_col], color='blue', alpha=0.7)

    ax.set_ylabel(y_col)

    # --- Best-Fit Line Logic ---
    if add_best_fit and len(df) > 1:
        m, b = np.polyfit(df[x_col], df[y_col], 1)
        ax.plot(df[x_col], m * df[x_col] + b, color='red', linewidth=2)
        
    # --- Set Integer Ticks for 'age' ---
    if x_col == 'age':
        ax.xaxis.set_major_locator(mticker.MultipleLocator(1))
    if y_col == 'age':
        ax.yaxis.set_major_locator(mticker.MultipleLocator(1))


def plot_bar_on_axis(ax, df, x_col, y_col):
    """Creates a bar plot. 'age' data should be rounded before calling this."""
    df = df.dropna(subset=[x_col, y_col])
    
    if df.empty:
        ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center')
        return

    is_x_categorical = df[x_col].dtype == 'object'
    is_y_categorical = df[y_col].dtype == 'object'

    # Only create a grouped chart if BOTH are categorical (e.g., gender vs. major)
    if is_x_categorical and is_y_categorical:
        # --- "color for gender vs major" plot ---
        grouped = pd.crosstab(df[x_col], df[y_col])
        
        n_groups = len(grouped)
        n_categories = len(grouped.columns)
        
        # Calculate bar width and positions
        total_width_per_group = 0.8
        bar_width = total_width_per_group / n_categories
        x_indices = np.arange(n_groups) # The center position for each group
        
        # Loop for each category (e.g., 'Male', 'Female') and plot its bars
        for i, category in enumerate(grouped.columns):
            # Calculate the position for this specific bar
            offset = (i - (n_categories - 1) / 2) * bar_width
            bar_positions = x_indices + offset
            
            counts = grouped[category]
            ax.bar(bar_positions, counts, width=bar_width, label=category, alpha=0.8)

        ax.set_ylabel("Count")
        ax.legend(title=y_col, loc='upper right')
        
        # Set the x-ticks to be in the center of the groups
        ax.set_xticks(x_indices)
        ax.set_xticklabels(grouped.index, rotation=45, ha='right')

    else:
        # --- This is the simple bar chart (e.g., "gender vs. age") ---
        grouped_data = df.groupby(x_col)[y_col].count()
        
        # --- THIS IS THE FIX ---
        # Get labels and positions
        x_labels = grouped_data.index.astype(str)
        x_indices = np.arange(len(x_labels))

        # Plot the bar
        ax.bar(x_indices, grouped_data.values, color='skyblue')
        ax.set_ylabel("Count")
        
        # NOW we set the ticks and labels explicitly
        ax.set_xticks(x_indices) # Tell matplotlib where the ticks are
        ax.set_xticklabels(x_labels, rotation=45, ha='right') # NOW set the labels
        # --- END FIX ---


# --- 7. When user clicks "Generate Plot" (Simple Version) ---
def generate_plot(b):
    with output_area:
        clear_output(wait=True)
        
        # --- Close all old, stray plots to prevent the black bar ---
        plt.close('all') 

        # --- 1. Get User Selections ---
        table = table_dropdown.value
        x_col = x_dropdown.value
        y_col = y_dropdown.value
        ptype = plot_type.value
        color_col = color_dropdown.value
        
        if not table or not x_col or not y_col:
            print("⚠️ Please select a table and valid X/Y columns.")
            return

        # --- 2. Load Full Data ---
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', conn)
        conn.close()

        # --- 3. Apply Age Filter (if 'age' column exists) ---
        if 'age' in df.columns:
            df['age'] = pd.to_numeric(df['age'], errors='coerce')
            df = df.dropna(subset=['age'])
            df = df[(df['age'] >= 18) & (df['age'] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        if df.empty:
            print("No data available after filtering.")
            return

        # --- 4. Create Plot ---
        fig, ax = plt.subplots(figsize=(7, 4))

        if ptype == "Scatter":
            plot_scatter_on_axis(ax, df, x_col, y_col, color_col, best_fit.value)
        else:
            if 'age' in df.columns:
                df['age'] = df['age'].round().astype(int)
            plot_bar_on_axis(ax, df, x_col, y_col)

        # --- 5. Final Plot Formatting ---
        ax.xaxis.tick_bottom()
        ax.xaxis.set_label_position('bottom')
        ax.yaxis.tick_left()
        ax.yaxis.set_label_position('left')
        
        ax.set_xlabel(x_col)
        ax.set_title(f"{y_col} vs. {x_col} ({ptype})")
        ax.grid(True, linestyle='--', alpha=0.5)
        
        plt.tight_layout()

        # --- 6. Show Image in Output ---
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        display(widgets.Image(value=buf.read(), format='png'))
        
        # We still close the *current* fig, just to be tidy
        plt.close(fig)

# --- 8. Link Widgets to Events ---
table_dropdown.observe(update_columns, names='Svalue')
plot_type.observe(update_columns, names='value')
generate_button.on_click(generate_plot)

# --- 9. Display Interface ---
controls = widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    # Swapped dropdowns as requested
    widgets.HBox([y_dropdown, x_dropdown, color_dropdown]),
    widgets.HBox([generate_button, best_fit])
])

ui = widgets.VBox([controls, output_area])

# Trigger the update once to populate the column dropdowns
update_columns(None)

display(ui)

VBox(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', '…

In [14]:
def plot_bar(ax, df, x_col, y_col):
    """Improved bar plot — auto-spaces labels if there are many X values."""
    df = df.dropna(subset=[x_col, y_col])

    if df.empty:
        ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center')
        return

    is_x_cat = df[x_col].dtype == 'object'
    is_y_cat = df[y_col].dtype == 'object'

    if is_x_cat and is_y_cat:
        # --- Grouped Bar Chart ---
        grouped = pd.crosstab(df[x_col], df[y_col])
        grouped.plot(kind='bar', ax=ax, width=0.8, alpha=0.8)
        ax.legend(title=y_col)
        plt.xticks(rotation=45, ha='right')
        ax.set_ylabel("Count")

    else:
        # --- Single Bar Chart (e.g., age vs student_id or similar) ---
        grouped = df.groupby(x_col)[y_col].count().sort_index()
        x_labels = grouped.index.astype(str)
        x_indices = np.arange(len(x_labels))

        # Thin bars if there are many
        bar_width = max(0.8 / np.log1p(len(x_labels)), 0.2)
        ax.bar(x_indices, grouped.values, color='skyblue', width=bar_width)
        ax.set_ylabel("Count")

        # --- Smarter tick spacing ---
        if len(x_labels) > 20:
            # Show every Nth label
            step = max(len(x_labels) // 20, 1)
            ax.set_xticks(x_indices[::step])
            ax.set_xticklabels(x_labels[::step], rotation=45, ha='right')
        else:
            ax.set_xticks(x_indices)
            ax.set_xticklabels(x_labels, rotation=45, ha='right')

        # Avoid clipping labels
        ax.margins(x=0.01)

        # If the X column looks like an ID, make note
        if "id" in x_col.lower():
            ax.set_xlabel(f"{x_col} (sampled labels for readability)")


In [15]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output
import io

# --- 2. Database Setup ---
DB_NAME = 'Dataset.db'

# --- 3. Get All Tables from Database ---
try:
    conn = sqlite3.connect(DB_NAME)
    tables = pd.read_sql_query(
        "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
        conn
    )
    table_list = tables['name'].tolist()
    conn.close()
except Exception as e:
    print("Error loading database:", e)
    table_list = []

# --- 4. Create Widgets ---
table_dropdown = widgets.Dropdown(options=table_list, description="Table:")
plot_type = widgets.Dropdown(options=["Scatter", "Bar"], description="Plot Type:")
x_dropdown = widgets.Dropdown(description="X-Axis:")
y_dropdown = widgets.Dropdown(description="Y-Axis:")
color_dropdown = widgets.Dropdown(description="Color by:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style='success')
output_area = widgets.Output()

# --- 5. Update Columns When Table or Plot Type Changes ---
def update_columns(change):
    with output_area:
        pass  # Keep UI stable

    table = table_dropdown.value
    ptype = plot_type.value
    if not table:
        return

    # Load sample data
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 50', conn)
    conn.close()

    all_cols = df.columns.tolist()
    numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    color_options = ['None'] + categorical_cols

    if ptype == "Scatter":
        x_dropdown.options = numeric_cols
        y_dropdown.options = numeric_cols
        color_dropdown.options = color_options
        color_dropdown.layout.display = 'flex'
        best_fit.layout.display = 'flex'
    else:
        x_dropdown.options = all_cols
        y_dropdown.options = all_cols
        color_dropdown.layout.display = 'none'
        best_fit.layout.display = 'none'

    if x_dropdown.options:
        x_dropdown.value = x_dropdown.options[0]
    if y_dropdown.options and len(y_dropdown.options) > 1:
        y_dropdown.value = y_dropdown.options[1]
    elif y_dropdown.options:
        y_dropdown.value = y_dropdown.options[0]
    if color_dropdown.options:
        color_dropdown.value = 'None'


# --- 6. Plotting Helper Functions ---
def plot_scatter(ax, df, x_col, y_col, color_col, add_best_fit):
    df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
    df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
    df = df.dropna(subset=[x_col, y_col])

    if df.empty:
        ax.text(0.5, 0.5, "No numeric data to plot.", ha='center', va='center')
        return

    # Color handling
    if color_col and color_col != 'None' and color_col in df.columns:
        df = df.dropna(subset=[color_col])
        categories = df[color_col].unique()
        for cat in categories:
            ax.scatter(df[df[color_col] == cat][x_col],
                       df[df[color_col] == cat][y_col],
                       alpha=0.7, label=str(cat))
        ax.legend(title=color_col)
    else:
        ax.scatter(df[x_col], df[y_col], color='blue', alpha=0.7)

    # Best-fit line
    if add_best_fit and len(df) > 1:
        m, b = np.polyfit(df[x_col], df[y_col], 1)
        ax.plot(df[x_col], m * df[x_col] + b, color='red', linewidth=2)

    if x_col == 'age':
        ax.xaxis.set_major_locator(mticker.MultipleLocator(1))
    if y_col == 'age':
        ax.yaxis.set_major_locator(mticker.MultipleLocator(1))

    ax.set_ylabel(y_col)


def plot_bar(ax, df, x_col, y_col):
    df = df.dropna(subset=[x_col, y_col])
    if df.empty:
        ax.text(0.5, 0.5, "No data to plot.", ha='center', va='center')
        return

    is_x_cat = df[x_col].dtype == 'object'
    is_y_cat = df[y_col].dtype == 'object'

    if is_x_cat and is_y_cat:
        # Grouped bar (e.g., gender vs major)
        grouped = pd.crosstab(df[x_col], df[y_col])
        grouped.plot(kind='bar', ax=ax, width=0.8, alpha=0.8)
        ax.legend(title=y_col)
        plt.xticks(rotation=45, ha='right')
        ax.set_ylabel("Count")
    else:
        # Standard bar chart (e.g., age vs student_id)
        grouped = df.groupby(x_col)[y_col].count().sort_index()
        x_labels = grouped.index.astype(str)
        x_indices = np.arange(len(x_labels))

        # Thin bars if many labels
        bar_width = max(0.8 / np.log1p(len(x_labels)), 0.2)
        ax.bar(x_indices, grouped.values, color='skyblue', width=bar_width)
        ax.set_ylabel("Count")

        # Smarter tick spacing
        if len(x_labels) > 20:
            step = max(len(x_labels) // 20, 1)
            ax.set_xticks(x_indices[::step])
            ax.set_xticklabels(x_labels[::step], rotation=45, ha='right')
        else:
            ax.set_xticks(x_indices)
            ax.set_xticklabels(x_labels, rotation=45, ha='right')

        ax.margins(x=0.01)
        if "id" in x_col.lower():
            ax.set_xlabel(f"{x_col} (sampled labels for readability)")


# --- 7. Generate Plot Function ---
def generate_plot(b):
    with output_area:
        clear_output(wait=True)
        plt.close('all')

        table = table_dropdown.value
        x_col = x_dropdown.value
        y_col = y_dropdown.value
        ptype = plot_type.value
        color_col = color_dropdown.value

        if not table or not x_col or not y_col:
            print("⚠️ Please select a table and valid X/Y columns.")
            return

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f'SELECT * FROM "{table}"', conn)
        conn.close()

        # Age filter
        if 'age' in df.columns:
            df['age'] = pd.to_numeric(df['age'], errors='coerce')
            df = df.dropna(subset=['age'])
            df = df[(df['age'] >= 18) & (df['age'] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        if df.empty:
            print("No data available after filtering.")
            return

        fig, ax = plt.subplots(figsize=(7, 4))

        if ptype == "Scatter":
            plot_scatter(ax, df, x_col, y_col, color_col, best_fit.value)
        else:
            if 'age' in df.columns:
                df['age'] = df['age'].round().astype(int)
            plot_bar(ax, df, x_col, y_col)

        # Final formatting
        ax.set_xlabel(x_col)
        ax.set_title(f"{y_col} vs. {x_col} ({ptype})")
        ax.grid(True, linestyle='--', alpha=0.5)
        plt.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        display(widgets.Image(value=buf.read(), format='png'))
        plt.close(fig)

# --- 8. Link Widgets ---
table_dropdown.observe(update_columns, names='value')
plot_type.observe(update_columns, names='value')
generate_button.on_click(generate_plot)

# --- 9. Display UI ---
controls = widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([y_dropdown, x_dropdown, color_dropdown]),
    widgets.HBox([generate_button, best_fit])
])

ui = widgets.VBox([controls, output_area])
update_columns(None)
display(ui)


VBox(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', '…

In [16]:
# Simple Jupyter Plot App
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Connect to database
DB_NAME = "Dataset.db"
conn = sqlite3.connect(DB_NAME)
tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
table_names = tables["name"].tolist()
conn.close()

# Widgets
table_dropdown = widgets.Dropdown(options=table_names, description="Table:")
plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
x_axis = widgets.Dropdown(description="X-Axis:")
y_axis = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style="success")
output = widgets.Output()

# Update dropdowns when table changes
def update_columns(change):
    table = table_dropdown.value
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()

    cols = df.columns.tolist()
    x_axis.options = cols
    y_axis.options = cols
    if len(cols) > 1:
        x_axis.value = cols[0]
        y_axis.value = cols[1]

table_dropdown.observe(update_columns, names="value")

# Plot when button is clicked
def generate_plot(b):
    with output:
        clear_output(wait=True)
        
        table = table_dropdown.value
        x_col = x_axis.value
        y_col = y_axis.value
        kind = plot_type.value

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # Filter by age if column exists
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        # Drop missing data
        df = df.dropna(subset=[x_col, y_col])
        if df.empty:
            print("No data to plot.")
            return

        plt.figure(figsize=(7,4))
        
        if kind == "Scatter":
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            if best_fit.value:
                # Add simple best-fit line
                if df[x_col].dtype != 'object' and df[y_col].dtype != 'object':
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")
        else:
            # Group by X and count Y
            counts = df.groupby(x_col)[y_col].count()
            plt.bar(counts.index.astype(str), counts.values, color="skyblue")
            
            # Make X labels readable if too many
            if len(counts) > 20:
                step = max(1, len(counts)//20)
                plt.xticks(counts.index[::step], rotation=45)
            else:
                plt.xticks(rotation=45)
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.ylabel(y_col if kind=="Scatter" else "Count")
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

generate_button.on_click(generate_plot)

# Display all controls
display(widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_axis, y_axis]),
    widgets.HBox([generate_button, best_fit]),
    output
]))


VBox(children=(HBox(children=(Dropdown(description='Table:', options=('sqlite_sequence', 'Students', 'Academic…

In [17]:
# Simple Jupyter Plot App
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Connect to database
DB_NAME = "Dataset.db"
conn = sqlite3.connect(DB_NAME)
tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
table_names = [t for t in tables["name"].tolist() if t != "sqlite_sequence"]  # remove sqlite_sequence
conn.close()

# Widgets
table_dropdown = widgets.Dropdown(options=table_names, description="Table:")
plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
x_axis = widgets.Dropdown(description="X-Axis:")
y_axis = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style="success")
output = widgets.Output()

# Update dropdowns when table changes
def update_columns(change):
    table = table_dropdown.value
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()

    cols = df.columns.tolist()
    x_axis.options = cols
    y_axis.options = cols
    if len(cols) > 1:
        x_axis.value = cols[0]
        y_axis.value = cols[1]

table_dropdown.observe(update_columns, names="value")

# Plot when button is clicked
def generate_plot(b):
    with output:
        clear_output(wait=True)
        
        table = table_dropdown.value
        x_col = x_axis.value
        y_col = y_axis.value
        kind = plot_type.value

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # Filter by age if column exists
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        # Drop missing data
        df = df.dropna(subset=[x_col, y_col])
        if df.empty:
            print("No data to plot.")
            return

        plt.figure(figsize=(7,4))
        
        if kind == "Scatter":
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            if best_fit.value:
                # Add simple best-fit line
                if df[x_col].dtype != 'object' and df[y_col].dtype != 'object':
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")
        else:
            # Group by X and count Y
            counts = df.groupby(x_col)[y_col].count()
            plt.bar(counts.index.astype(str), counts.values, color="skyblue")
            
            # Make X labels readable if too many
            if len(counts) > 20:
                step = max(1, len(counts)//20)
                plt.xticks(counts.index[::step], rotation=45)
            else:
                plt.xticks(rotation=45)
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.ylabel(y_col if kind=="Scatter" else "Count")
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

generate_button.on_click(generate_plot)

# Display all controls
display(widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_axis, y_axis]),
    widgets.HBox([generate_button, best_fit]),
    output
]))


VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [18]:
# Simple Data Plot App for Beginners
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Connect to database
DB_NAME = "Dataset.db"
conn = sqlite3.connect(DB_NAME)
tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
table_names = [t for t in tables["name"].tolist() if t != "sqlite_sequence"]
conn.close()

# Widgets
table_dropdown = widgets.Dropdown(options=table_names, description="Table:")
plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
x_axis = widgets.Dropdown(description="X-Axis:")
y_axis = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style="success")
output = widgets.Output()

# Update dropdowns when table changes
def update_columns(change):
    table = table_dropdown.value
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()

    cols = df.columns.tolist()
    x_axis.options = cols
    y_axis.options = cols
    if len(cols) > 1:
        x_axis.value = cols[0]
        y_axis.value = cols[1]

table_dropdown.observe(update_columns, names="value")

# Plot when button is clicked
def generate_plot(b):
    with output:
        clear_output(wait=True)
        
        table = table_dropdown.value
        x_col = x_axis.value
        y_col = y_axis.value
        kind = plot_type.value

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # Filter by age if it exists
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        df = df.dropna(subset=[x_col, y_col])
        if df.empty:
            print("No data to plot.")
            return

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            if best_fit.value:
                if df[x_col].dtype != 'object' and df[y_col].dtype != 'object':
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Case 1: Both are categorical (grouped bar)
            if df[x_col].dtype == "object" and df[y_col].dtype == "object":
                grouped = pd.crosstab(df[x_col], df[y_col])  # counts per category
                ages = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns)
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        ages + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.xticks(ages + bar_width*(len(grouped.columns)/2 - 0.5), grouped.index, rotation=45)
                plt.legend(title=y_col)
                plt.ylabel("Count")

            # Case 2: One numeric, one categorical (simple count)
            else:
                counts = df.groupby(x_col)[y_col].count()
                plt.bar(counts.index.astype(str), counts.values, color="skyblue")
                if len(counts) > 20:
                    step = max(1, len(counts)//20)
                    plt.xticks(counts.index[::step], rotation=45)
                else:
                    plt.xticks(rotation=45)
                plt.ylabel("Count")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

generate_button.on_click(generate_plot)

# Display
display(widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_axis, y_axis]),
    widgets.HBox([generate_button, best_fit]),
    output
]))


VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [19]:
# Simple Data Plot App for Beginners
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Connect to database
DB_NAME = "Dataset.db"
conn = sqlite3.connect(DB_NAME)
tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
table_names = [t for t in tables["name"].tolist() if t != "sqlite_sequence"]
conn.close()

# Widgets
table_dropdown = widgets.Dropdown(options=table_names, description="Table:")
plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
x_axis = widgets.Dropdown(description="X-Axis:")
y_axis = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style="success")
output = widgets.Output()

# Update dropdowns when table changes
def update_columns(change):
    table = table_dropdown.value
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()

    cols = df.columns.tolist()
    x_axis.options = cols
    y_axis.options = cols
    if len(cols) > 1:
        x_axis.value = cols[0]
        y_axis.value = cols[1]

table_dropdown.observe(update_columns, names="value")

# Plot when button is clicked
def generate_plot(b):
    with output:
        clear_output(wait=True)
        
        table = table_dropdown.value
        x_col = x_axis.value
        y_col = y_axis.value
        kind = plot_type.value

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # Filter by age if it exists
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        df = df.dropna(subset=[x_col, y_col])
        if df.empty:
            print("No data to plot.")
            return

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            if best_fit.value:
                if df[x_col].dtype != 'object' and df[y_col].dtype != 'object':
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Case 1: Both are categorical (grouped bar)
            if df[x_col].dtype == "object" and df[y_col].dtype == "object":
                grouped = pd.crosstab(df[x_col], df[y_col])  # counts per category
                ages = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns)
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        ages + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.xticks(ages + bar_width*(len(grouped.columns)/2 - 0.5), grouped.index, rotation=45)
                plt.legend(title=y_col)
                plt.ylabel("Count")

            # Case 2: One numeric, one categorical (simple count)
            else:
                counts = df.groupby(x_col)[y_col].count()
                plt.bar(counts.index.astype(str), counts.values, color="skyblue")
                if len(counts) > 20:
                    step = max(1, len(counts)//20)
                    plt.xticks(counts.index[::step], rotation=45)
                else:
                    plt.xticks(rotation=45)
                plt.ylabel("Count")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

generate_button.on_click(generate_plot)

# Display
display(widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_axis, y_axis]),
    widgets.HBox([generate_button, best_fit]),
    output
]))


VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [20]:
# Simple Data Plot App for Beginners
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Connect to database
DB_NAME = "Dataset.db"
conn = sqlite3.connect(DB_NAME)
tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
table_names = [t for t in tables["name"].tolist() if t != "sqlite_sequence"]
conn.close()

# Widgets
table_dropdown = widgets.Dropdown(options=table_names, description="Table:")
plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
x_axis = widgets.Dropdown(description="X-Axis:")
y_axis = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style="success")
output = widgets.Output()

# Update dropdowns when table changes
def update_columns(change):
    table = table_dropdown.value
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()

    cols = df.columns.tolist()
    x_axis.options = cols
    y_axis.options = cols
    if len(cols) > 1:
        x_axis.value = cols[0]
        y_axis.value = cols[1]

table_dropdown.observe(update_columns, names="value")

# Plot when button is clicked
def generate_plot(b):
    with output:
        clear_output(wait=True)
        
        table = table_dropdown.value
        x_col = x_axis.value
        y_col = y_axis.value
        kind = plot_type.value

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # Filter by age if it exists
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        df = df.dropna(subset=[x_col, y_col])
        if df.empty:
            print("No data to plot.")
            return

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            if best_fit.value:
                if df[x_col].dtype != 'object' and df[y_col].dtype != 'object':
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Convert age to string to allow grouped bars (treat numeric as category)
            if x_col == "age":
                df["age"] = df["age"].astype(int).astype(str)
            
            # Grouped bar: count of Y per X
            grouped = pd.crosstab(df[x_col], df[y_col])
            groups = np.arange(len(grouped))
            bar_width = 0.8 / len(grouped.columns)

            for i, cat in enumerate(grouped.columns):
                plt.bar(
                    groups + i * bar_width,
                    grouped[cat].values,
                    width=bar_width,
                    label=str(cat),
                    alpha=0.8
                )

            plt.xticks(groups + bar_width*(len(grouped.columns)/2 - 0.5), grouped.index, rotation=45)
            plt.ylabel("Count")
            plt.legend(title=y_col)

        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

generate_button.on_click(generate_plot)

# Display
display(widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_axis, y_axis]),
    widgets.HBox([generate_button, best_fit]),
    output
]))


VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [21]:
# Simple Jupyter Plot App with explicit age vs gender bars
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Database file name
DB_NAME = "Dataset.db"

# Read table names (skip sqlite_sequence)
conn = sqlite3.connect(DB_NAME)
tables_df = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
table_names = [t for t in tables_df["name"].tolist() if t != "sqlite_sequence"]
conn.close()

# Widgets
table_dropdown = widgets.Dropdown(options=table_names, description="Table:")
plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
x_axis = widgets.Dropdown(description="X-Axis:")
y_axis = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style="success")
output = widgets.Output()

# Populate column dropdowns when table changes
def update_columns(change):
    table = table_dropdown.value
    if not table:
        return
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()
    cols = df.columns.tolist()
    x_axis.options = cols
    y_axis.options = cols
    if len(cols) > 1:
        x_axis.value = cols[0]
        y_axis.value = cols[1]

table_dropdown.observe(update_columns, names="value")

# Plotting
def generate_plot(b):
    with output:
        clear_output(wait=True)
        table = table_dropdown.value
        x_col = x_axis.value
        y_col = y_axis.value
        kind = plot_type.value

        if not table or not x_col or not y_col:
            print("Select a table and both X and Y columns.")
            return

        # Load full data
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # Age filter if present
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        # Drop rows missing selected columns
        df = df.dropna(subset=[x_col, y_col])
        if df.empty:
            print("No data to plot after filtering.")
            return

        plt.figure(figsize=(8,4))

        if kind == "Scatter":
            # Simple scatter
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            # Best-fit line only when both numeric
            if best_fit.value:
                if pd.api.types.is_numeric_dtype(df[x_col]) and pd.api.types.is_numeric_dtype(df[y_col]):
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col] + b, color="red")
            plt.xlabel(x_col)
            plt.ylabel(y_col)

        else:
            # ---------- Special explicit case: age (X) vs gender (Y) ----------
            if x_col == "age" and y_col.lower() == "gender":
                # Ensure age is integer and string for ticks
                df["age"] = df["age"].round().astype(int)
                # Keep only the ages 18-22 in sorted order (if present)
                ages = sorted(df["age"].unique())
                ages = [a for a in ages if 18 <= a <= 22]

                # Get counts per age for Male and Female (case-insensitive)
                df["gender_norm"] = df["gender"].astype(str).str.strip().str.lower()
                male_counts = []
                female_counts = []
                for a in ages:
                    sub = df[df["age"] == a]
                    male_counts.append((sub["gender_norm"] == "male").sum())
                    female_counts.append((sub["gender_norm"] == "female").sum())

                x = np.arange(len(ages))
                width = 0.35
                plt.bar(x - width/2, male_counts, width=width, label="Male")
                plt.bar(x + width/2, female_counts, width=width, label="Female")
                plt.xticks(x, [str(a) for a in ages], rotation=45)
                plt.ylabel("Count")
                plt.xlabel("age")
                plt.legend()
                plt.title("gender counts by age (Male vs Female)")
            else:
                # ---------- Generic bar: use crosstab to make grouped bars ----------
                # Treat numeric X as categorical for bar grouping
                if pd.api.types.is_numeric_dtype(df[x_col]):
                    df[x_col] = df[x_col].round().astype(int).astype(str)
                else:
                    df[x_col] = df[x_col].astype(str)

                df[y_col] = df[y_col].astype(str)
                grouped = pd.crosstab(df[x_col], df[y_col])
                if grouped.shape[1] == 1:
                    # Single series: just plot counts
                    counts = grouped.iloc[:,0].values
                    labels = grouped.index.tolist()
                    plt.bar(labels, counts, color="skyblue")
                    plt.xticks(rotation=45)
                    plt.ylabel("Count")
                    plt.xlabel(x_col)
                else:
                    # Multiple categories: grouped bars
                    n_groups = len(grouped)
                    n_cats = grouped.shape[1]
                    bar_width = 0.8 / n_cats
                    indices = np.arange(n_groups)
                    for i, col in enumerate(grouped.columns):
                        plt.bar(indices + i*bar_width, grouped[col].values, width=bar_width, label=str(col))
                    plt.xticks(indices + (n_cats-1)*bar_width/2, grouped.index, rotation=45)
                    plt.ylabel("Count")
                    plt.xlabel(x_col)
                    plt.legend(title=y_col)

        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

generate_button.on_click(generate_plot)

# Initialize columns for the first table if present
if table_names:
    table_dropdown.value = table_names[0]
    update_columns(None)

# Display the UI
display(widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_axis, y_axis]),
    widgets.HBox([generate_button, best_fit]),
    output
]))


VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [22]:
# Simple Data Plot App for Beginners
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Connect to database
DB_NAME = "Dataset.db"
conn = sqlite3.connect(DB_NAME)
tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
table_names = [t for t in tables["name"].tolist() if t != "sqlite_sequence"]
conn.close()

# Widgets
table_dropdown = widgets.Dropdown(options=table_names, description="Table:")
plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
x_axis = widgets.Dropdown(description="X-Axis:")
y_axis = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style="success")
output = widgets.Output()

# Update dropdowns when table changes
def update_columns(change):
    table = table_dropdown.value
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()

    cols = df.columns.tolist()
    x_axis.options = cols
    y_axis.options = cols
    if len(cols) > 1:
        x_axis.value = cols[0]
        y_axis.value = cols[1]

table_dropdown.observe(update_columns, names="value")

# Plot when button is clicked
def generate_plot(b):
    with output:
        clear_output(wait=True)
        
        table = table_dropdown.value
        x_col = x_axis.value
        y_col = y_axis.value
        kind = plot_type.value

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # Filter by age if it exists
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        df = df.dropna(subset=[x_col, y_col])
        if df.empty:
            print("No data to plot.")
            return

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            if best_fit.value:
                if df[x_col].dtype != 'object' and df[y_col].dtype != 'object':
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Case 1: Both are categorical (grouped bar)
            if df[x_col].dtype == "object" and df[y_col].dtype == "object":
                grouped = pd.crosstab(df[x_col], df[y_col])  # counts per category
                ages = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns)
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        ages + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.xticks(ages + bar_width*(len(grouped.columns)/2 - 0.5), grouped.index, rotation=45)
                plt.legend(title=y_col)
                plt.ylabel("Count")

            # Case 2: One numeric, one categorical (simple count)
            else:
                counts = df.groupby(x_col)[y_col].count()
                plt.bar(counts.index.astype(str), counts.values, color="skyblue")
                if len(counts) > 20:
                    step = max(1, len(counts)//20)
                    plt.xticks(counts.index[::step], rotation=45)
                else:
                    plt.xticks(rotation=45)
                plt.ylabel("Count")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

generate_button.on_click(generate_plot)

# Display
display(widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_axis, y_axis]),
    widgets.HBox([generate_button, best_fit]),
    output
]))


VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [23]:
# Simple Data Plot App for Beginners (with Male/Female per Age)
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Connect to database
DB_NAME = "Dataset.db"
conn = sqlite3.connect(DB_NAME)
tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
table_names = [t for t in tables["name"].tolist() if t != "sqlite_sequence"]
conn.close()

# Widgets
table_dropdown = widgets.Dropdown(options=table_names, description="Table:")
plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
x_axis = widgets.Dropdown(description="X-Axis:")
y_axis = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style="success")
output = widgets.Output()

# Update dropdowns when table changes
def update_columns(change):
    table = table_dropdown.value
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()

    cols = df.columns.tolist()
    x_axis.options = cols
    y_axis.options = cols
    if len(cols) > 1:
        x_axis.value = cols[0]
        y_axis.value = cols[1]

table_dropdown.observe(update_columns, names="value")

# Plot when button is clicked
def generate_plot(b):
    with output:
        clear_output(wait=True)
        
        table = table_dropdown.value
        x_col = x_axis.value
        y_col = y_axis.value
        kind = plot_type.value

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # Filter by age if it exists
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        df = df.dropna(subset=[x_col, y_col])
        if df.empty:
            print("No data to plot.")
            return

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            if best_fit.value:
                if df[x_col].dtype != 'object' and df[y_col].dtype != 'object':
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Special case: gender vs age (show male/female bars per age)
            if (x_col == "age" and y_col == "gender") or (x_col == "gender" and y_col == "age"):
                # Make sure "age" is numeric
                df["age"] = pd.to_numeric(df["age"], errors="coerce")
                df = df.dropna(subset=["age", "gender"])
                
                grouped = df.groupby(["age", "gender"]).size().unstack(fill_value=0)
                ages = np.arange(len(grouped))
                bar_width = 0.35

                # Male and Female bars
                if "Male" in grouped.columns:
                    plt.bar(ages - bar_width/2, grouped["Male"], width=bar_width, label="Male", color="skyblue")
                if "Female" in grouped.columns:
                    plt.bar(ages + bar_width/2, grouped["Female"], width=bar_width, label="Female", color="pink")

                plt.xticks(ages, grouped.index.astype(int))
                plt.ylabel("Count")
                plt.xlabel("Age")
                plt.legend(title="Gender")

            # Case 1: Both categorical (grouped bar)
            elif df[x_col].dtype == "object" and df[y_col].dtype == "object":
                grouped = pd.crosstab(df[x_col], df[y_col])
                labels = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns)
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        labels + i * bar_width,
                        grouped[col].values,
                        width=bar_width,
                        label=str(col),
                        alpha=0.8
                    )
                
                plt.xticks(labels + bar_width*(len(grouped.columns)/2 - 0.5), grouped.index, rotation=45)
                plt.legend(title=y_col)
                plt.ylabel("Count")

            # Case 2: One numeric, one categorical (simple count)
            else:
                counts = df.groupby(x_col)[y_col].count()
                plt.bar(counts.index.astype(str), counts.values, color="skyblue")
                if len(counts) > 20:
                    step = max(1, len(counts)//20)
                    plt.xticks(counts.index[::step], rotation=45)
                else:
                    plt.xticks(rotation=45)
                plt.ylabel("Count")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

generate_button.on_click(generate_plot)

# Display
display(widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_axis, y_axis]),
    widgets.HBox([generate_button, best_fit]),
    output
]))


VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [24]:
# Simple Data Plot App for Beginners (fixed gender legend issue)
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Connect to database
DB_NAME = "Dataset.db"
conn = sqlite3.connect(DB_NAME)
tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
table_names = [t for t in tables["name"].tolist() if t != "sqlite_sequence"]
conn.close()

# Widgets
table_dropdown = widgets.Dropdown(options=table_names, description="Table:")
plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
x_axis = widgets.Dropdown(description="X-Axis:")
y_axis = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style="success")
output = widgets.Output()

# Update dropdowns when table changes
def update_columns(change):
    table = table_dropdown.value
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()

    cols = df.columns.tolist()
    x_axis.options = cols
    y_axis.options = cols
    if len(cols) > 1:
        x_axis.value = cols[0]
        y_axis.value = cols[1]

table_dropdown.observe(update_columns, names="value")

# Plot when button is clicked
def generate_plot(b):
    with output:
        clear_output(wait=True)
        
        table = table_dropdown.value
        x_col = x_axis.value
        y_col = y_axis.value
        kind = plot_type.value

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # Filter by age if available
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        df = df.dropna(subset=[x_col, y_col])
        if df.empty:
            print("No data to plot.")
            return

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            if best_fit.value:
                if df[x_col].dtype != 'object' and df[y_col].dtype != 'object':
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Special case: gender vs age
            if (x_col == "age" and y_col == "gender") or (x_col == "gender" and y_col == "age"):
                df["age"] = pd.to_numeric(df["age"], errors="coerce")
                df = df.dropna(subset=["age", "gender"])
                
                grouped = df.groupby(["age", "gender"]).size().unstack(fill_value=0)
                ages = np.arange(len(grouped))
                bar_width = 0.35

                bars_drawn = []

                if "Male" in grouped.columns:
                    bars_drawn.append(
                        plt.bar(ages - bar_width/2, grouped["Male"], width=bar_width, label="Male", color="skyblue")
                    )
                if "Female" in grouped.columns:
                    bars_drawn.append(
                        plt.bar(ages + bar_width/2, grouped["Female"], width=bar_width, label="Female", color="pink")
                    )

                plt.xticks(ages, grouped.index.astype(int))
                plt.ylabel("Count")
                plt.xlabel("Age")

                if bars_drawn:  # only show legend if bars exist
                    plt.legend(title="Gender")

            # Case 1: Both categorical (grouped bar)
            elif df[x_col].dtype == "object" and df[y_col].dtype == "object":
                grouped = pd.crosstab(df[x_col], df[y_col])
                labels = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns)
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        labels + i * bar_width,
                        grouped[col].values,
                        width=bar_width,
                        label=str(col),
                        alpha=0.8
                    )
                
                plt.xticks(labels + bar_width*(len(grouped.columns)/2 - 0.5), grouped.index, rotation=45)
                plt.legend(title=y_col)
                plt.ylabel("Count")

            # Case 2: One numeric, one categorical (simple count)
            else:
                counts = df.groupby(x_col)[y_col].count()
                plt.bar(counts.index.astype(str), counts.values, color="skyblue")
                if len(counts) > 20:
                    step = max(1, len(counts)//20)
                    plt.xticks(counts.index[::step], rotation=45)
                else:
                    plt.xticks(rotation=45)
                plt.ylabel("Count")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

generate_button.on_click(generate_plot)

# Display
display(widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_axis, y_axis]),
    widgets.HBox([generate_button, best_fit]),
    output
]))


VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [25]:
# Simple Data Plot App for Beginners
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Connect to database
DB_NAME = "Dataset.db"
conn = sqlite3.connect(DB_NAME)
tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
table_names = [t for t in tables["name"].tolist() if t != "sqlite_sequence"]
conn.close()

# Widgets
table_dropdown = widgets.Dropdown(options=table_names, description="Table:")
plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
x_axis = widgets.Dropdown(description="X-Axis:")
y_axis = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style="success")
output = widgets.Output()

# Update dropdowns when table changes
def update_columns(change):
    table = table_dropdown.value
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()

    cols = df.columns.tolist()
    x_axis.options = cols
    y_axis.options = cols
    if len(cols) > 1:
        x_axis.value = cols[0]
        y_axis.value = cols[1]

table_dropdown.observe(update_columns, names="value")

# Plot when button is clicked
def generate_plot(b):
    with output:
        clear_output(wait=True)
        
        table = table_dropdown.value
        x_col = x_axis.value
        y_col = y_axis.value
        kind = plot_type.value

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # Filter by age if it exists
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        df = df.dropna(subset=[x_col, y_col])
        if df.empty:
            print("No data to plot.")
            return

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            if best_fit.value:
                if df[x_col].dtype != 'object' and df[y_col].dtype != 'object':
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Case 1: Both are categorical (grouped bar)
            if df[x_col].dtype == "object" and df[y_col].dtype == "object":
                grouped = pd.crosstab(df[x_col], df[y_col])  # counts per category
                ages = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns)
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        ages + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.xticks(ages + bar_width*(len(grouped.columns)/2 - 0.5), grouped.index, rotation=45)
                plt.legend(title=y_col)
                plt.ylabel("Count")

            # Case 2: One numeric, one categorical (simple count)
            else:
                counts = df.groupby(x_col)[y_col].count()
                plt.bar(counts.index.astype(str), counts.values, color="skyblue")
                if len(counts) > 20:
                    step = max(1, len(counts)//20)
                    plt.xticks(counts.index[::step], rotation=45)
                else:
                    plt.xticks(rotation=45)
                plt.ylabel("Count")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

generate_button.on_click(generate_plot)

# Display
display(widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_axis, y_axis]),
    widgets.HBox([generate_button, best_fit]),
    output
]))


VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [26]:
# Simple Data Plot App for Beginners
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Connect to database
DB_NAME = "Dataset.db"
conn = sqlite3.connect(DB_NAME)
tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
table_names = [t for t in tables["name"].tolist() if t != "sqlite_sequence"]
conn.close()

# Widgets
table_dropdown = widgets.Dropdown(options=table_names, description="Table:")
plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
x_axis = widgets.Dropdown(description="X-Axis:")
y_axis = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style="success")
output = widgets.Output()

# Update dropdowns when table changes
def update_columns(change):
    table = table_dropdown.value
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()

    cols = df.columns.tolist()
    x_axis.options = cols
    y_axis.options = cols
    if len(cols) > 1:
        x_axis.value = cols[0]
        y_axis.value = cols[1]

table_dropdown.observe(update_columns, names="value")

# Plot when button is clicked
def generate_plot(b):
    with output:
        clear_output(wait=True)
        
        table = table_dropdown.value
        x_col = x_axis.value
        y_col = y_axis.value
        kind = plot_type.value

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # Filter by age if it exists
        if "age" in df.columns:
            # --- Added rounding for bar charts ---
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            if kind == "Bar":
                 df['age'] = df['age'].round().astype(int)
            print(f"Filter applied: {len(df)} students aged 18–22.")

        df = df.dropna(subset=[x_col, y_col])
        if df.empty:
            print("No data to plot.")
            return

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            # Convert to numeric for scatter
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col])
            
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if best_fit.value:
                # Check dtypes again after conversion
                if df[x_col].dtype != 'object' and df[y_col].dtype != 'object' and len(df) > 1:
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # --- THIS IS THE FIX ---
            # If Y is categorical (like 'gender'), make a grouped bar chart
            if df[y_col].dtype == "object":
                # Case 1: Grouped bar chart (e.g., X=age, Y=gender)
                grouped = pd.crosstab(df[x_col], df[y_col]) # counts per category
                ages = np.arange(len(grouped))
                # Use a narrower bar width for grouped charts
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        ages + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                # Center the x-ticks between the groups
                plt.xticks(ages + bar_width*(len(grouped.columns)/2 - 0.5), grouped.index, rotation=45, ha="right")
                plt.legend(title=y_col)
                plt.ylabel("Count")

            # --- THIS LOGIC IS ALSO IMPROVED ---
            # Case 2: Simple count (e.g., X=age, Y=student_id)
            else:
                counts = df.groupby(x_col)[y_col].count()
                
                # Get labels and integer positions for plotting
                labels = counts.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, counts.values, color="skyblue")
                plt.ylabel("Count")

                # Set ticks
                if len(positions) > 50:
                    # If too many, hide labels to prevent black bar
                    plt.xticks([])
                elif len(positions) > 20:
                    # If a lot, show every Nth label
                    step = max(1, len(positions) // 20)
                    plt.xticks(positions[::step], labels[::step], rotation=45, ha="right")
                else:
                    # Otherwise, show all
                    plt.xticks(positions, labels, rotation=45, ha="right")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

generate_button.on_click(generate_plot)

# Display
display(widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_axis, y_axis]),
    widgets.HBox([generate_button, best_fit]),
    output
]))

# Trigger the update once to load the first table's columns
if table_names:
    update_columns(None)

VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [27]:
# Simple Data Plot App for Beginners
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Connect to database
DB_NAME = "Dataset.db"
conn = sqlite3.connect(DB_NAME)
tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
table_names = [t for t in tables["name"].tolist() if t != "sqlite_sequence"]
conn.close()

# Widgets
table_dropdown = widgets.Dropdown(options=table_names, description="Table:")
plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
x_axis = widgets.Dropdown(description="X-Axis:")
y_axis = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style="success")
output = widgets.Output()

# Update dropdowns when table changes
def update_columns(change):
    table = table_dropdown.value
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()

    cols = df.columns.tolist()
    x_axis.options = cols
    y_axis.options = cols
    if len(cols) > 1:
        x_axis.value = cols[0]
        y_axis.value = cols[1]

table_dropdown.observe(update_columns, names="value")

# Plot when button is clicked
def generate_plot(b):
    with output:
        clear_output(wait=True)
        
        table = table_dropdown.value
        x_col = x_axis.value
        y_col = y_axis.value
        kind = plot_type.value

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # Filter by age if it exists
        if "age" in df.columns:
            # Added rounding for bar charts
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            if kind == "Bar":
                 df['age'] = df['age'].round().astype(int)
            print(f"Filter applied: {len(df)} students aged 18–22.")

        df = df.dropna(subset=[x_col, y_col])
        if df.empty:
            print("No data to plot.")
            return

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            # Convert to numeric for scatter
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col])
            
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if best_fit.value:
                # Check dtypes again after conversion
                if df[x_col].dtype != 'object' and df[y_col].dtype != 'object' and len(df) > 1:
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # If Y is categorical (like 'gender'), make a grouped bar chart
            if df[y_col].dtype == "object":
                # Case 1: Grouped bar chart (e.g., X=age, Y=gender)
                grouped = pd.crosstab(df[x_col], df[y_col]) # counts per category
                ages = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        ages + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count")

                # --- THIS IS THE FIX ---
                # Get labels and positions for ticks
                labels = grouped.index.astype(str)
                positions = ages + bar_width*(len(grouped.columns)/2 - 0.5)

                if len(positions) > 50:
                    # If too many, hide labels to prevent black bar
                    plt.xticks([])
                elif len(positions) > 20:
                    # If a lot, show every Nth label
                    step = max(1, len(positions) // 20)
                    plt.xticks(positions[::step], labels[::step], rotation=45, ha="right")
                else:
                    # Otherwise, show all
                    plt.xticks(positions, labels, rotation=45, ha="right")
                # --- END FIX ---

            # Case 2: Simple count (e.g., X=age, Y=student_id)
            else:
                counts = df.groupby(x_col)[y_col].count()
                
                # Get labels and integer positions for plotting
                labels = counts.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, counts.values, color="skyblue")
                plt.ylabel("Count")

                # Set ticks (This logic was already correct)
                if len(positions) > 50:
                    plt.xticks([])
                elif len(positions) > 20:
                    step = max(1, len(positions) // 20)
                    plt.xticks(positions[::step], labels[::step], rotation=45, ha="right")
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

generate_button.on_click(generate_plot)

# Display
display(widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_axis, y_axis]),
    widgets.HBox([generate_button, best_fit]),
    output
]))

# Trigger the update once to load the first table's columns
if table_names:
    update_columns(None)

VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [28]:
# Simple Data Plot App for Beginners
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Connect to database
DB_NAME = "Dataset.db"
conn = sqlite3.connect(DB_NAME)
tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
table_names = [t for t in tables["name"].tolist() if t != "sqlite_sequence"]
conn.close()

# Widgets
table_dropdown = widgets.Dropdown(options=table_names, description="Table:")
plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
x_axis = widgets.Dropdown(description="X-Axis:")
y_axis = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style="success")
output = widgets.Output()

# Update dropdowns when table changes
def update_columns(change):
    table = table_dropdown.value
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()

    cols = df.columns.tolist()
    x_axis.options = cols
    y_axis.options = cols
    if len(cols) > 1:
        x_axis.value = cols[0]
        y_axis.value = cols[1]

table_dropdown.observe(update_columns, names="value")

# Plot when button is clicked
def generate_plot(b):
    with output:
        clear_output(wait=True)
        
        table = table_dropdown.value
        x_col = x_axis.value
        y_col = y_axis.value
        kind = plot_type.value

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # Filter by age if it exists
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            print(f"Filter applied: {len(df)} students aged 18–22.")

        df = df.dropna(subset=[x_col, y_col])
        if df.empty:
            print("No data to plot.")
            return

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            if best_fit.value:
                if df[x_col].dtype != 'object' and df[y_col].dtype != 'object':
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Case 1: Both are categorical (grouped bar)
            if df[x_col].dtype == "object" and df[y_col].dtype == "object":
                grouped = pd.crosstab(df[x_col], df[y_col])  # counts per category
                ages = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns)
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        ages + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.xticks(ages + bar_width*(len(grouped.columns)/2 - 0.5), grouped.index, rotation=45)
                plt.legend(title=y_col)
                plt.ylabel("Count")

            # Case 2: One numeric, one categorical (simple count)
            else:
                counts = df.groupby(x_col)[y_col].count()
                plt.bar(counts.index.astype(str), counts.values, color="skyblue")
                if len(counts) > 20:
                    step = max(1, len(counts)//20)
                    plt.xticks(counts.index[::step], rotation=45)
                else:
                    plt.xticks(rotation=45)
                plt.ylabel("Count")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

generate_button.on_click(generate_plot)

# Display
display(widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_axis, y_axis]),
    widgets.HBox([generate_button, best_fit]),
    output
]))


VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [29]:
# Simple Data Plot App for Beginners
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Connect to database
DB_NAME = "Dataset.db"
conn = sqlite3.connect(DB_NAME)
tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
table_names = [t for t in tables["name"].tolist() if t != "sqlite_sequence"]
conn.close()

# Widgets
table_dropdown = widgets.Dropdown(options=table_names, description="Table:")
plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
x_axis = widgets.Dropdown(description="X-Axis:")
y_axis = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style="success")
output = widgets.Output()

# --- THIS FUNCTION IS NOW FIXED ---
# Update dropdowns when table or plot type changes
def update_columns(change):
    table = table_dropdown.value
    ptype = plot_type.value # Get the plot type
    
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()

    if ptype == "Scatter":
        # SCATTER: Only show numeric columns
        numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
        x_axis.options = numeric_cols
        y_axis.options = numeric_cols
        best_fit.layout.display = 'flex' # Show checkbox
    else:
        # BAR: Show all columns
        all_cols = df.columns.tolist()
        x_axis.options = all_cols
        y_axis.options = all_cols
        best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if x_axis.options:
        x_axis.value = x_axis.options[0]
    if y_axis.options and len(y_axis.options) > 1:
        y_axis.value = y_axis.options[1]
    elif y_axis.options:
        y_axis.value = y_axis.options[0]

# --- OBSERVERS ARE NOW FIXED ---
table_dropdown.observe(update_columns, names="value")
plot_type.observe(update_columns, names="value") # Added this observer

# Plot when button is clicked
def generate_plot(b):
    with output:
        clear_output(wait=True)
        
        table = table_dropdown.value
        x_col = x_axis.value
        y_col = y_axis.value
        kind = plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # Filter by age if it exists
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            if kind == "Bar":
                 df['age'] = df['age'].round().astype(int)
            print(f"Filter applied: {len(df)} students aged 18–22.")

        # --- THIS IS THE OTHER FIX ---
        # The .dropna() is now moved inside each plot's logic

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            # Convert to numeric and drop NAs
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close() # Close the empty figure
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if best_fit.value:
                if len(df) > 1:
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close() # Close the empty figure
                return

            # If Y is categorical (like 'gender'), make a grouped bar chart
            if df[y_col].dtype == "object":
                # Case 1: Grouped bar chart (e.g., X=age, Y=gender)
                grouped = pd.crosstab(df[x_col], df[y_col]) 
                ages = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        ages + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count")

                # Fix for clumped labels
                labels = grouped.index.astype(str)
                positions = ages + bar_width*(len(grouped.columns)/2 - 0.5)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 2: Simple count (e.g., X=age, Y=student_id)
            else:
                counts = df.groupby(x_col)[y_col].count()
                labels = counts.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, counts.values, color="skyblue")
                plt.ylabel("Count")

                # Fix for clumped labels
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

generate_button.on_click(generate_plot)

# Display
display(widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_axis, y_axis]), # Kept your original swapped layout
    widgets.HBox([generate_button, best_fit]),
    output
]))

# Trigger the update once to load the first table's columns
if table_names:
    update_columns(None)

VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [30]:
# Simple Data Plot App for Beginners
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Connect to database
DB_NAME = "Dataset.db"
conn = sqlite3.connect(DB_NAME)
tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
table_names = [t for t in tables["name"].tolist() if t != "sqlite_sequence"]
conn.close()

# Widgets
table_dropdown = widgets.Dropdown(options=table_names, description="Table:")
plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
x_axis = widgets.Dropdown(description="X-Axis:")
y_axis = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style="success")
output = widgets.Output()

# Update dropdowns when table or plot type changes
def update_columns(change):
    table = table_dropdown.value
    ptype = plot_type.value # Get the plot type
    
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()

    if ptype == "Scatter":
        # SCATTER: Only show numeric columns
        numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
        x_axis.options = numeric_cols
        y_axis.options = numeric_cols
        best_fit.layout.display = 'flex' # Show checkbox
    else:
        # BAR: Show all columns
        all_cols = df.columns.tolist()
        x_axis.options = all_cols
        y_axis.options = all_cols
        best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if x_axis.options:
        x_axis.value = x_axis.options[0]
    if y_axis.options and len(y_axis.options) > 1:
        y_axis.value = y_axis.options[1]
    elif y_axis.options:
        y_axis.value = y_axis.options[0]

# Observers
table_dropdown.observe(update_columns, names="value")
plot_type.observe(update_columns, names="value")

# Plot when button is clicked
def generate_plot(b):
    with output:
        clear_output(wait=True)
        
        table = table_dropdown.value
        x_col = x_axis.value
        y_col = y_axis.value
        kind = plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # Filter by age if it exists
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            # We round 'age' *before* plotting
            if 'age' in df.columns:
                df['age'] = df['age'].round()
            print(f"Filter applied: {len(df)} students aged 18–22.")

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            # Convert to numeric and drop NAs
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close() # Close the empty figure
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if best_fit.value:
                if len(df) > 1:
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close() # Close the empty figure
                return

            # If Y is categorical (like 'gender'), make a grouped bar chart
            if df[y_col].dtype == "object":
                # Case 1: Grouped bar chart (e.g., X=age, Y=gender)
                grouped = pd.crosstab(df[x_col], df[y_col]) 
                ages = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        ages + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count") # This is still a count

                # Fix for clumped labels
                labels = grouped.index.astype(str)
                positions = ages + bar_width*(len(grouped.columns)/2 - 0.5)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # --- THIS IS THE FIX ---
            # Case 2: Y-column is numeric, so show AVERAGE
            else:
                # Calculate the mean (average)
                agg_data = df.groupby(x_col)[y_col].mean()
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                # Set the label to show it's an average
                plt.ylabel(f"Average {y_col}") 

                # Fix for clumped labels (handles X-axis)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            # --- END FIX ---
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

generate_button.on_click(generate_plot)

# Display
display(widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_axis, y_axis]), # Kept your original swapped layout
    widgets.HBox([generate_button, best_fit]),
    output
]))

# Trigger the update once to load the first table's columns
if table_names:
    update_columns(None)

VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [31]:
# Simple Data Plot App for Beginners
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Connect to database
DB_NAME = "Dataset.db"
conn = sqlite3.connect(DB_NAME)
tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
table_names = [t for t in tables["name"].tolist() if t != "sqlite_sequence"]
conn.close()

# Widgets
table_dropdown = widgets.Dropdown(options=table_names, description="Table:")
plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
x_axis = widgets.Dropdown(description="X-Axis:")
y_axis = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style="success")
output = widgets.Output()

# Update dropdowns when table or plot type changes
def update_columns(change):
    table = table_dropdown.value
    ptype = plot_type.value # Get the plot type
    
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()

    if ptype == "Scatter":
        # SCATTER: Only show numeric columns
        numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
        x_axis.options = numeric_cols
        y_axis.options = numeric_cols
        best_fit.layout.display = 'flex' # Show checkbox
    else:
        # BAR: Show all columns
        all_cols = df.columns.tolist()
        x_axis.options = all_cols
        y_axis.options = all_cols
        best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if x_axis.options:
        x_axis.value = x_axis.options[0]
    if y_axis.options and len(y_axis.options) > 1:
        y_axis.value = y_axis.options[1]
    elif y_axis.options:
        y_axis.value = y_axis.options[0]

# Observers
table_dropdown.observe(update_columns, names="value")
plot_type.observe(update_columns, names="value")

# Plot when button is clicked
def generate_plot(b):
    with output:
        clear_output(wait=True)
        
        table = table_dropdown.value
        x_col = x_axis.value
        y_col = y_axis.value
        kind = plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # Filter by age if it exists
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            # We round 'age' *before* plotting
            if 'age' in df.columns:
                df['age'] = df['age'].round()
            print(f"Filter applied: {len(df)} students aged 18–22.")

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            # Convert to numeric and drop NAs
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close() # Close the empty figure
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if best_fit.value:
                if len(df) > 1:
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close() # Close the empty figure
                return
            
            # --- THIS IS THE NEW, SMARTER LOGIC ---
            is_x_categorical = df[x_col].dtype == 'object'
            is_y_categorical = df[y_col].dtype == 'object'

            # Case 1: Y is categorical (e.g., X=age, Y=gender)
            # This is the grouped "Count" chart
            if is_y_categorical:
                grouped = pd.crosstab(df[x_col], df[y_col]) 
                ages = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        ages + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count") # This is a count

                labels = grouped.index.astype(str)
                positions = ages + bar_width*(len(grouped.columns)/2 - 0.5)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 2: Y is numeric and X is categorical (e.g., X=gender, Y=age)
            # This is the "Average" chart
            elif is_x_categorical and not is_y_categorical:
                agg_data = df.groupby(x_col)[y_col].mean()
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Average {y_col}") # This is an average

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 3: Both X and Y are numeric (e.g., X=age, Y=student_id)
            # This is the simple "Count" chart
            else:
                agg_data = df.groupby(x_col)[y_col].count()
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Count of {y_col}") # This is a count

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            # --- END NEW LOGIC ---
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

generate_button.on_click(generate_plot)

# Display
display(widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_axis, y_axis]), # Kept your original swapped layout
    widgets.HBox([generate_button, best_fit]),
    output
]))

# Trigger the update once to load the first table's columns
if table_names:
    update_columns(None)

VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [32]:
# Simple Data Plot App for Beginners
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Connect to database
DB_NAME = "Dataset.db"
conn = sqlite3.connect(DB_NAME)
tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
table_names = [t for t in tables["name"].tolist() if t != "sqlite_sequence"]
conn.close()

# Widgets
table_dropdown = widgets.Dropdown(options=table_names, description="Table:")
plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
x_axis = widgets.Dropdown(description="X-Axis:")
y_axis = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style="success")
output = widgets.Output()

# Update dropdowns when table or plot type changes
def update_columns(change):
    table = table_dropdown.value
    ptype = plot_type.value # Get the plot type
    
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()

    if ptype == "Scatter":
        # SCATTER: Only show numeric columns
        numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
        x_axis.options = numeric_cols
        y_axis.options = numeric_cols
        best_fit.layout.display = 'flex' # Show checkbox
    else:
        # BAR: Show all columns
        all_cols = df.columns.tolist()
        x_axis.options = all_cols
        y_axis.options = all_cols
        best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if x_axis.options:
        x_axis.value = x_axis.options[0]
    if y_axis.options and len(y_axis.options) > 1:
        y_axis.value = y_axis.options[1]
    elif y_axis.options:
        y_axis.value = y_axis.options[0]

# Observers
table_dropdown.observe(update_columns, names="value")
plot_type.observe(update_columns, names="value")

# Plot when button is clicked
def generate_plot(b):
    with output:
        clear_output(wait=True)
        
        table = table_dropdown.value
        x_col = x_axis.value
        y_col = y_axis.value
        kind = plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # Filter by age if it exists
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            original_count = len(df) # Get count before filtering
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            # We round 'age' *before* plotting
            if 'age' in df.columns:
                df['age'] = df['age'].round()
            print(f"Filter applied: Kept {len(df)} of {original_count} students aged 18–22.")

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            # Convert to numeric and drop NAs
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close() # Close the empty figure
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if best_fit.value:
                if len(df) > 1:
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Try to convert columns to numeric where possible
            # This helps our "is_categorical" check
            x_data = pd.to_numeric(df[x_col], errors='coerce')
            if x_data.isnull().all(): # Failed to convert (it's text)
                x_data = df[x_col]
                
            y_data = pd.to_numeric(df[y_col], errors='coerce')
            if y_data.isnull().all(): # Failed to convert (it's text)
                y_data = df[y_col]
                
            df[x_col] = x_data
            df[y_col] = y_data
            
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close() # Close the empty figure
                return
            
            # --- THIS IS THE NEW, SMARTER LOGIC ---
            
            # A column is "categorical" if it's an 'object' (text)
            # OR if it's numeric but has a small number of unique values.
            def is_col_categorical(col_name, unique_thresh=25):
                # Check if dtype is object
                if df[col_name].dtype == 'object':
                    return True
                # Check if it's numeric and has few unique values
                if pd.api.types.is_numeric_dtype(df[col_name]):
                    if df[col_name].nunique() < unique_thresh:
                        return True
                return False

            is_x_categorical = is_col_categorical(x_col)
            is_y_categorical = is_col_categorical(y_col)
            # --- END NEW LOGIC ---
            
            # Case 1: Y is categorical (e.g., X=age, Y=gender)
            # This is the grouped "Count" chart
            if is_y_categorical:
                grouped = pd.crosstab(df[x_col], df[y_col]) 
                ages = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        ages + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count") # This is a count

                labels = grouped.index.astype(str)
                positions = ages + bar_width*(len(grouped.columns)/2 - 0.5)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 2: Y is numeric (not categorical) and X is categorical
            # (e.g., X=gender, Y=age)
            # This is the "Average" chart
            elif is_x_categorical and not is_y_categorical:
                agg_data = df.groupby(x_col)[y_col].mean()
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Average {y_col}") # This is an average

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 3: Both X and Y are numeric (and not categorical)
            # (e.g., X=age, Y=student_id)
            # This is the simple "Count" chart
            else:
                agg_data = df.groupby(x_col)[y_col].count()
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Count of {y_col}") # This is a count

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            # --- END LOGIC ---
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

generate_button.on_click(generate_plot)

# Display
display(widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_axis, y_axis]), # Kept your original swapped layout
    widgets.HBox([generate_button, best_fit]),
    output
]))

# Trigger the update once to load the first table's columns
if table_names:
    update_columns(None)

VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [33]:
# Simple Data Plot App for Beginners
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Connect to database
DB_NAME = "Dataset.db"
conn = sqlite3.connect(DB_NAME)
tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
table_names = [t for t in tables["name"].tolist() if t != "sqlite_sequence"]
conn.close()

# Widgets
table_dropdown = widgets.Dropdown(options=table_names, description="Table:")
plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
x_axis = widgets.Dropdown(description="X-Axis:")
y_axis = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style="success")
output = widgets.Output()

# Update dropdowns when table or plot type changes
def update_columns(change):
    table = table_dropdown.value
    ptype = plot_type.value # Get the plot type
    
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()

    if ptype == "Scatter":
        # SCATTER: Only show numeric columns
        numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
        x_axis.options = numeric_cols
        y_axis.options = numeric_cols
        best_fit.layout.display = 'flex' # Show checkbox
    else:
        # BAR: Show all columns
        all_cols = df.columns.tolist()
        x_axis.options = all_cols
        y_axis.options = all_cols
        best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if x_axis.options:
        x_axis.value = x_axis.options[0]
    if y_axis.options and len(y_axis.options) > 1:
        y_axis.value = y_axis.options[1]
    elif y_axis.options:
        y_axis.value = y_axis.options[0]

# Observers
table_dropdown.observe(update_columns, names="value")
plot_type.observe(update_columns, names="value")

# Plot when button is clicked
def generate_plot(b):
    with output:
        clear_output(wait=True)
        
        table = table_dropdown.value
        x_col = x_axis.value
        y_col = y_axis.value
        kind = plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # Filter by age if it exists
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            # We round 'age' *before* plotting
            if 'age' in df.columns:
                df['age'] = df['age'].round()
            print(f"Filter applied: {len(df)} students aged 18–22.")

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            # Convert to numeric and drop NAs
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close() # Close the empty figure
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if best_fit.value:
                if len(df) > 1:
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close() # Close the empty figure
                return
            
            # --- THIS IS THE NEW, SMARTER LOGIC ---
            is_x_categorical = df[x_col].dtype == 'object'
            is_y_categorical = df[y_col].dtype == 'object'

            # Case 1: Y is categorical (e.g., X=age, Y=gender)
            # This is the grouped "Count" chart
            if is_y_categorical:
                grouped = pd.crosstab(df[x_col], df[y_col]) 
                ages = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        ages + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count") # This is a count

                labels = grouped.index.astype(str)
                positions = ages + bar_width*(len(grouped.columns)/2 - 0.5)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 2: Y is numeric and X is categorical (e.g., X=gender, Y=age)
            # This is the "Average" chart
            elif is_x_categorical and not is_y_categorical:
                agg_data = df.groupby(x_col)[y_col].mean()
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Average {y_col}") # This is an average

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 3: Both X and Y are numeric (e.g., X=age, Y=student_id)
            # This is the simple "Count" chart
            else:
                agg_data = df.groupby(x_col)[y_col].count()
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Count of {y_col}") # This is a count

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            # --- END NEW LOGIC ---
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

generate_button.on_click(generate_plot)

# Display
display(widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_axis, y_axis]), # Kept your original swapped layout
    widgets.HBox([generate_button, best_fit]),
    output
]))

# Trigger the update once to load the first table's columns
if table_names:
    update_columns(None)

VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [1]:
# Simple Data Plot App for Beginners
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display, clear_output

# Connect to database
DB_NAME = "Dataset.db"
conn = sqlite3.connect(DB_NAME)
tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type='table';", conn)
table_names = [t for t in tables["name"].tolist() if t != "sqlite_sequence"]
conn.close()

# Widgets
table_dropdown = widgets.Dropdown(options=table_names, description="Table:")
plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
x_axis = widgets.Dropdown(description="X-Axis:")
y_axis = widgets.Dropdown(description="Y-Axis:")
best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
generate_button = widgets.Button(description="Generate Plot", button_style="success")
output = widgets.Output()

# Update dropdowns when table or plot type changes
def update_columns(change):
    table = table_dropdown.value
    ptype = plot_type.value # Get the plot type
    
    conn = sqlite3.connect(DB_NAME)
    # Read only a few rows to quickly get column names and types
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()

    if ptype == "Scatter":
        # SCATTER: Only show numeric columns
        numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
        x_axis.options = numeric_cols
        y_axis.options = numeric_cols
        best_fit.layout.display = 'flex' # Show checkbox
    else:
        # BAR: Show all columns
        all_cols = df.columns.tolist()
        x_axis.options = all_cols
        y_axis.options = all_cols
        best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if x_axis.options:
        x_axis.value = x_axis.options[0]
    if y_axis.options and len(y_axis.options) > 1:
        # Try to set a different default for Y if possible
        y_axis.value = y_axis.options[1]
    elif y_axis.options:
        y_axis.value = y_axis.options[0]

# Observers
table_dropdown.observe(update_columns, names="value")
plot_type.observe(update_columns, names="value")

# Plot when button is clicked
def generate_plot(b):
    with output:
        clear_output(wait=True)
        
        table = table_dropdown.value
        x_col = x_axis.value
        y_col = y_axis.value
        kind = plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # --- Age Filter ---
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            original_count = len(df) # Get count before filtering
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            # We round 'age' *before* plotting
            if 'age' in df.columns:
                df['age'] = df['age'].round()
            print(f"Filter applied: Kept {len(df)} of {original_count} students aged 18–22.")

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            # Convert to numeric and drop NAs
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close() # Close the empty figure
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if best_fit.value:
                if len(df) > 1:
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Try to convert columns to numeric where possible
            # This helps our "is_categorical" check
            x_data = pd.to_numeric(df[x_col], errors='coerce')
            if x_data.isnull().all(): # Failed to convert (it's text)
                x_data = df[x_col]
                
            y_data = pd.to_numeric(df[y_col], errors='coerce')
            if y_data.isnull().all(): # Failed to convert (it't text)
                y_data = df[y_col]
                
            df[x_col] = x_data
            df[y_col] = y_data
            
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close() # Close the empty figure
                return
            
            # --- THIS IS THE NEW, SMARTER LOGIC ---
            
            # A column is "categorical" if it's an 'object' (text)
            # OR if it's numeric but has a small number of unique values.
            def is_col_categorical(col_name, unique_thresh=25):
                """Helper function to check if a column should be treated as categorical."""
                # Check if dtype is object (text)
                if pd.api.types.is_object_dtype(df[col_name]):
                    return True
                # Check if it's numeric and has few unique values
                if pd.api.types.is_numeric_dtype(df[col_name]):
                    if df[col_name].nunique() < unique_thresh:
                        # Don't treat a "count" column (like ID) as categorical
                        if 'id' not in str(col_name).lower():
                            return True
                return False

            is_x_categorical = is_col_categorical(x_col)
            is_y_categorical = is_col_categorical(y_col)
            # --- END NEW LOGIC ---
            
            # Case 1: Y is categorical (e.g., X=age, Y=gender)
            # This is the grouped "Count" chart
            if is_y_categorical:
                grouped = pd.crosstab(df[x_col], df[y_col]) 
                # Sort by index (X-axis) to ensure it's in order
                grouped = grouped.sort_index()
                
                indices = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        indices + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count") # This is a count

                labels = grouped.index.astype(str)
                positions = indices + bar_width*(len(grouped.columns)/2 - 0.5)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 2: Y is numeric (not categorical) and X is categorical
            # (e.g., X=future_career_concerns, Y=study_load)
            # This is the "Average" (Mean) chart
            elif is_x_categorical and not is_y_categorical:
                agg_data = df.groupby(x_col)[y_col].mean()
                # Sort by index (X-axis) to ensure it's in order (e.g., 0, 1, 2, 3...)
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Average {y_col}") # This is an average

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 3: Both X and Y are numeric (and not categorical)
            # (e.g., X=age, Y=student_id)
            # This is the simple "Count" chart
            else:
                agg_data = df.groupby(x_col)[y_col].count()
                # Sort by index (X-axis) to ensure it's in order
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Count of {y_col}") # This is a count

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            # --- END LOGIC ---
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

generate_button.on_click(generate_plot)

# Display
display(widgets.VBox([
    widgets.HBox([table_dropdown, plot_type]),
    widgets.HBox([x_axis, y_axis]), # Kept your original swapped layout
    widgets.HBox([generate_button, best_fit]),
    output
]))

# Trigger the update once to load the first table's columns
if table_names:
    update_columns(None)

VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological',…

In [1]:
# --- 1. Import All Necessary Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import GridspecLayout
from IPython.display import display, clear_output
import io
import numpy as np

# --- Import your custom SQL handler ---
# This assumes SQL_Handler.py is in the same folder as your notebook
try:
    import SQL_Handler
    print("Successfully imported SQL_Handler.py")
except ImportError:
    print("ERROR: SQL_Handler.py not found. Make sure it's in the same directory.")
except Exception as e:
    print(f"Error importing SQL_Handler: {e}")

# --- 2. Shared Database Helper Functions ---
DB_NAME = 'Dataset.db'

def get_table_names():
    """Fetches all table names from the database."""
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0].strip() for table in cursor.fetchall() if table[0] != 'sqlite_sequence']
        conn.close()
        return tables
    except Exception as e:
        print(f"Database error: {e}")
        return []

# Get table names once for all dropdowns
all_table_names = get_table_names()

# ===============================================
# --- 3. CREATE "DATA MANAGEMENT" TAB WIDGETS ---
# (This section is unchanged)
# ===============================================

manage_table_dropdown = widgets.Dropdown(
    description='Table:',
    options=all_table_names
)
manage_join_dropdown = widgets.Dropdown(
    description='Join Table:',
    options=["(None)"] + all_table_names,
    value="(None)"
)
manage_cols_text = widgets.Text(
    description='Columns:',
    value='*'
)
manage_cond_text = widgets.Text(
    description='WHERE:',
    placeholder='e.g., age > 20'
)
manage_button = widgets.Button(description='Run Query')
manage_output = widgets.Output()

def on_manage_query_click(b):
    """Event handler for the Data Management 'Run Query' button."""
    with manage_output:
        clear_output(wait=True)
        
        # Get values from widgets
        table = manage_table_dropdown.value
        cols = manage_cols_text.value
        cond = manage_cond_text.value or None
        join = manage_join_dropdown.value
        
        if join == "(None)":
            join = None
            
        print(f"Querying: SELECT {cols} FROM {table}...")
        if join:
            print(f"Joining with: {join}")
        if cond:
            print(f"Condition: WHERE {cond}")
            
        try:
            # --- THIS IS WHERE WE USE YOUR SQL_Handler.py ---
            results = SQL_Handler.data_selection(table, cols, cond, join)
            # ------------------------------------------------
            
            if results:
                # Convert list of tuples to DataFrame for nice display
                df = pd.DataFrame(results)
                print(f"\nSuccess! Found {len(results)} rows.")
                display(df)
            else:
                print("\nQuery executed, but returned no results.")
                
        except Exception as e:
            print(f"\nAn error occurred: {e}")
            print("---")
            print("Make sure your SQL syntax is correct and columns exist.")

manage_button.on_click(on_manage_query_click)

# Layout for the first tab
data_management_tab = widgets.VBox([
    widgets.HBox([manage_table_dropdown, manage_join_dropdown]),
    widgets.HBox([manage_cols_text, manage_cond_text]),
    manage_button,
    manage_output
])


# ===============================================
# --- 4. CREATE "DATA GRAPHING" TAB WIDGETS ---
# (This section is replaced with your new code)
# ===============================================

# Widgets (renamed to avoid conflicts)
graph_table_dropdown = widgets.Dropdown(options=all_table_names, description="Table:")
graph_plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
graph_x_axis = widgets.Dropdown(description="X-Axis:")
graph_y_axis = widgets.Dropdown(description="Y-Axis:")
graph_best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
graph_generate_button = widgets.Button(description="Generate Plot", button_style="success")
graph_output = widgets.Output()

# Update dropdowns when table or plot type changes
def graph_update_columns(change):
    table = graph_table_dropdown.value
    ptype = graph_plot_type.value # Get the plot type
    
    conn = sqlite3.connect(DB_NAME)
    # Read only a few rows to quickly get column names and types
    try:
        df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    except Exception as e:
        with graph_output:
            print(f"Error reading table {table}: {e}")
            graph_x_axis.options = []
            graph_y_axis.options = []
        return
    finally:
        conn.close()

    if ptype == "Scatter":
        # SCATTER: Only show numeric columns
        numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
        graph_x_axis.options = numeric_cols
        graph_y_axis.options = numeric_cols
        graph_best_fit.layout.display = 'flex' # Show checkbox
    else:
        # BAR: Show all columns
        all_cols = df.columns.tolist()
        graph_x_axis.options = all_cols
        graph_y_axis.options = all_cols
        graph_best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if graph_x_axis.options:
        graph_x_axis.value = graph_x_axis.options[0]
    if graph_y_axis.options and len(graph_y_axis.options) > 1:
        # Try to set a different default for Y if possible
        graph_y_axis.value = graph_y_axis.options[1]
    elif graph_y_axis.options:
        graph_y_axis.value = graph_y_axis.options[0]

# Observers
graph_table_dropdown.observe(graph_update_columns, names="value")
graph_plot_type.observe(graph_update_columns, names="value")

# Plot when button is clicked
def graph_generate_plot(b):
    with graph_output:
        clear_output(wait=True)
        
        table = graph_table_dropdown.value
        x_col = graph_x_axis.value
        y_col = graph_y_axis.value
        kind = graph_plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        try:
            conn = sqlite3.connect(DB_NAME)
            df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
            conn.close()
        except Exception as e:
            print(f"Error querying database: {e}")
            return

        # --- Age Filter ---
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            original_count = len(df) # Get count before filtering
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            # We round 'age' *before* plotting
            if 'age' in df.columns:
                df['age'] = df['age'].round()
            print(f"Filter applied: Kept {len(df)} of {original_count} students aged 18–22.")

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            # Convert to numeric and drop NAs
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close() # Close the empty figure
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if graph_best_fit.value:
                if len(df) > 1:
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Try to convert columns to numeric where possible
            x_data = pd.to_numeric(df[x_col], errors='coerce')
            if x_data.isnull().all(): # Failed to convert (it's text)
                x_data = df[x_col]
                
            y_data = pd.to_numeric(df[y_col], errors='coerce')
            if y_data.isnull().all(): # Failed to convert (it't text)
                y_data = df[y_col]
                
            df[x_col] = x_data
            df[y_col] = y_data
            
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close() # Close the empty figure
                return
            
            # --- THIS IS THE NEW, SMARTER LOGIC ---
            
            # A column is "categorical" if it's an 'object' (text)
            # OR if it's numeric but has a small number of unique values.
            def is_col_categorical(col_name, unique_thresh=25):
                """Helper function to check if a column should be treated as categorical."""
                # Check if dtype is object (text)
                if pd.api.types.is_object_dtype(df[col_name]):
                    return True
                # Check if it's numeric and has few unique values
                if pd.api.types.is_numeric_dtype(df[col_name]):
                    if df[col_name].nunique() < unique_thresh:
                        # Don't treat a "count" column (like ID) as categorical
                        if 'id' not in str(col_name).lower():
                            return True
                return False

            is_x_categorical = is_col_categorical(x_col)
            is_y_categorical = is_col_categorical(y_col)
            # --- END NEW LOGIC ---
            
            # Case 1: Y is categorical (e.g., X=age, Y=gender)
            # This is the grouped "Count" chart
            if is_y_categorical:
                grouped = pd.crosstab(df[x_col], df[y_col]) 
                # Sort by index (X-axis) to ensure it's in order
                grouped = grouped.sort_index()
                
                indices = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        indices + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count") # This is a count

                labels = grouped.index.astype(str)
                positions = indices + bar_width*(len(grouped.columns)/2 - 0.5)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 2: Y is numeric (not categorical) and X is categorical
            # (e.g., X=future_career_concerns, Y=study_load)
            # This is the "Average" (Mean) chart
            elif is_x_categorical and not is_y_categorical:
                agg_data = df.groupby(x_col)[y_col].mean()
                # Sort by index (X-axis) to ensure it's in order (e.g., 0, 1, 2, 3...)
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Average {y_col}") # This is an average

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 3: Both X and Y are numeric (and not categorical)
            # (e.g., X=age, Y=student_id)
            # This is the simple "Count" chart
            else:
                agg_data = df.groupby(x_col)[y_col].count()
                # Sort by index (X-axis) to ensure it's in order
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Count of {y_col}") # This is a count

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            # --- END LOGIC ---
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

graph_generate_button.on_click(graph_generate_plot)

# Layout for the second tab
data_graphing_tab = widgets.VBox([
    widgets.HBox([graph_table_dropdown, graph_plot_type]),
    widgets.HBox([graph_x_axis, graph_y_axis]),
    widgets.HBox([graph_generate_button, graph_best_fit]),
    graph_output
])

# Trigger the update once to load the first table's columns
if all_table_names:
    graph_update_columns(None)


# ===============================================
# --- 5. ASSEMBLE AND DISPLAY THE TABS ---
# ===============================================

tab_container = widgets.Tab()
tab_container.children = [data_management_tab, data_graphing_tab]
tab_container.set_title(0, 'Data Management')
tab_container.set_title(1, 'Data Graphing')

display(tab_container)

Successfully imported SQL_Handler.py


Tab(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'P…

In [3]:
# --- 1. Import All Necessary Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import GridspecLayout
from IPython.display import display, clear_output
import io
import numpy as np

# --- Import your custom SQL handler ---
# This assumes SQL_Handler.py is in the same folder as your notebook
try:
    import SQL_Handler
    print("Successfully imported SQL_Handler.py")
except ImportError:
    print("ERROR: SQL_Handler.py not found. Make sure it's in the same directory.")
except Exception as e:
    print(f"Error importing SQL_Handler: {e}")

# --- 2. Shared Database Helper Functions ---
DB_NAME = 'Dataset.db'

def get_table_names():
    """Fetches all table names from the database."""
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0].strip() for table in cursor.fetchall() if table[0] != 'sqlite_sequence']
        conn.close()
        return tables
    except Exception as e:
        print(f"Database error: {e}")
        return []

# Get table names once for all dropdowns
all_table_names = get_table_names()

# ===============================================
# --- 3. CREATE "DATA MANAGEMENT" TAB WIDGETS ---
# ===============================================

manage_table_dropdown = widgets.Dropdown(
    description='Table:',
    options=all_table_names
)
manage_join_dropdown = widgets.Dropdown(
    description='Join Table:',
    options=["(None)"] + all_table_names,
    value="(None)"
)
manage_cols_text = widgets.Text(
    description='Columns:',
    value='*'
)
manage_cond_text = widgets.Text(
    description='WHERE:',
    placeholder='e.g., age > 20'
)
manage_button = widgets.Button(description='Run Query')
manage_output = widgets.Output()

def on_manage_query_click(b):
    """Event handler for the Data Management 'Run Query' button."""
    with manage_output:
        clear_output(wait=True)
        
        # Get values from widgets
        table = manage_table_dropdown.value
        cols = manage_cols_text.value
        cond = manage_cond_text.value or None
        join = manage_join_dropdown.value
        
        if join == "(None)":
            join = None
            
        print(f"Querying: SELECT {cols} FROM {table}...")
        if join:
            print(f"Joining with: {join}")
        if cond:
            print(f"Condition: WHERE {cond}")
            
        try:
            # --- THIS IS WHERE WE USE YOUR SQL_Handler.py ---
            results = SQL_Handler.data_selection(table, cols, cond, join)
            # ------------------------------------------------
            
            if results:
                # Convert list of tuples to DataFrame for nice display
                df = pd.DataFrame(results)
                print(f"\nSuccess! Found {len(results)} rows.")
                display(df)
            else:
                print("\nQuery executed, but returned no results.")
                
        except Exception as e:
            print(f"\nAn error occurred: {e}")
            print("---")
            print("Make sure your SQL syntax is correct and columns exist.")

manage_button.on_click(on_manage_query_click)

# Layout for the first tab
data_management_tab = widgets.VBox([
    widgets.HBox([manage_table_dropdown, manage_join_dropdown]),
    widgets.HBox([manage_cols_text, manage_cond_text]),
    manage_button,
    manage_output
])


# ===============================================
# --- 4. CREATE "DATA GRAPHING" TAB WIDGETS ---
# ===============================================

# Widgets (renamed to avoid conflicts)
graph_table_dropdown = widgets.Dropdown(options=all_table_names, description="Table:")
graph_plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
graph_x_axis = widgets.Dropdown(description="X-Axis:")
graph_y_axis = widgets.Dropdown(description="Y-Axis:")
graph_best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
graph_generate_button = widgets.Button(description="Generate Plot", button_style="success")
graph_output = widgets.Output()

# Update dropdowns when table or plot type changes
def graph_update_columns(change):
    table = graph_table_dropdown.value
    ptype = graph_plot_type.value # Get the plot type
    
    conn = sqlite3.connect(DB_NAME)
    # Read only a few rows to quickly get column names and types
    try:
        df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    except Exception as e:
        with graph_output:
            print(f"Error reading table {table}: {e}")
            graph_x_axis.options = []
            graph_y_axis.options = []
        return
    finally:
        conn.close()

    if ptype == "Scatter":
        # SCATTER: Only show numeric columns
        numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
        graph_x_axis.options = numeric_cols
        graph_y_axis.options = numeric_cols
        graph_best_fit.layout.display = 'flex' # Show checkbox
    else:
        # BAR: Show all columns
        all_cols = df.columns.tolist()
        graph_x_axis.options = all_cols
        graph_y_axis.options = all_cols
        graph_best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if graph_x_axis.options:
        graph_x_axis.value = graph_x_axis.options[0]
    if graph_y_axis.options and len(graph_y_axis.options) > 1:
        # Try to set a different default for Y if possible
        graph_y_axis.value = graph_y_axis.options[1]
    elif graph_y_axis.options:
        graph_y_axis.value = graph_y_axis.options[0]

# Observers
graph_table_dropdown.observe(graph_update_columns, names="value")
graph_plot_type.observe(graph_update_columns, names="value")

# Plot when button is clicked
def graph_generate_plot(b):
    with graph_output:
        clear_output(wait=True)
        
        table = graph_table_dropdown.value
        x_col = graph_x_axis.value
        y_col = graph_y_axis.value
        kind = graph_plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        try:
            conn = sqlite3.connect(DB_NAME)
            df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
            conn.close()
        except Exception as e:
            print(f"Error querying database: {e}")
            return

        # --- Age Filter ---
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            original_count = len(df) # Get count before filtering
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            # We round 'age' *before* plotting
            if 'age' in df.columns:
                df['age'] = df['age'].round()
            print(f"Filter applied: Kept {len(df)} of {original_count} students aged 18–22.")

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            # Convert to numeric and drop NAs
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close() # Close the empty figure
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if graph_best_fit.value:
                if len(df) > 1:
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Try to convert columns to numeric where possible
            x_data = pd.to_numeric(df[x_col], errors='coerce')
            if x_data.isnull().all(): # Failed to convert (it's text)
                x_data = df[x_col]
                
            y_data = pd.to_numeric(df[y_col], errors='coerce')
            if y_data.isnull().all(): # Failed to convert (it't text)
                y_data = df[y_col]
                
            df[x_col] = x_data
            df[y_col] = y_data
            
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close() # Close the empty figure
                return
            
            # --- THIS IS THE SMARTER LOGIC ---
            
            def is_col_categorical(col_name, unique_thresh=25):
                """Helper function to check if a column should be treated as categorical."""
                if pd.api.types.is_object_dtype(df[col_name]):
                    return True
                if pd.api.types.is_numeric_dtype(df[col_name]):
                    if df[col_name].nunique() < unique_thresh:
                        if 'id' not in str(col_name).lower():
                            return True
                return False

            is_x_categorical = is_col_categorical(x_col)
            is_y_categorical = is_col_categorical(y_col)
            
            # Case 1: Y is categorical (e.g., X=age, Y=gender)
            if is_y_categorical:
                grouped = pd.crosstab(df[x_col], df[y_col]) 
                grouped = grouped.sort_index()
                
                indices = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        indices + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count") # This is a count

                labels = grouped.index.astype(str)
                positions = indices + bar_width*(len(grouped.columns)/2 - 0.5)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 2: Y is numeric (not categorical) and X is categorical
            # (e.g., X=age, Y=student_id)
            # This is now a "Count" chart
            elif is_x_categorical and not is_y_categorical:
                # --- THIS IS THE CHANGED LINE ---
                agg_data = df.groupby(x_col)[y_col].count() 
                # ---------------------------------
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                
                # --- AND THIS IS THE CHANGED LINE ---
                plt.ylabel("Number of Students") 
                # ------------------------------------

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 3: Both X and Y are numeric (and not categorical)
            else:
                agg_data = df.groupby(x_col)[y_col].count()
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Count of {y_col}") # This is a count

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

graph_generate_button.on_click(graph_generate_plot)

# Layout for the second tab
data_graphing_tab = widgets.VBox([
    widgets.HBox([graph_table_dropdown, graph_plot_type]),
    widgets.HBox([graph_x_axis, graph_y_axis]),
    widgets.HBox([graph_generate_button, graph_best_fit]),
    graph_output
])

# Trigger the update once to load the first table's columns
if all_table_names:
    graph_update_columns(None)


# ===============================================
# --- 5. ASSEMB

Successfully imported SQL_Handler.py


In [4]:
# --- 1. Import All Necessary Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import GridspecLayout
from IPython.display import display, clear_output
import io
import numpy as np

# --- Import your custom SQL handler ---
# This assumes SQL_Handler.py is in the same folder as your notebook
try:
    import SQL_Handler
    print("Successfully imported SQL_Handler.py")
except ImportError:
    print("ERROR: SQL_Handler.py not found. Make sure it's in the same directory.")
except Exception as e:
    print(f"Error importing SQL_Handler: {e}")

# --- 2. Shared Database Helper Functions ---
DB_NAME = 'Dataset.db'

def get_table_names():
    """Fetches all table names from the database."""
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0].strip() for table in cursor.fetchall() if table[0] != 'sqlite_sequence']
        conn.close()
        return tables
    except Exception as e:
        print(f"Database error: {e}")
        return []

# Get table names once for all dropdowns
all_table_names = get_table_names()

# ===============================================
# --- 3. CREATE "DATA MANAGEMENT" TAB WIDGETS ---
# ===============================================

manage_table_dropdown = widgets.Dropdown(
    description='Table:',
    options=all_table_names
)
manage_join_dropdown = widgets.Dropdown(
    description='Join Table:',
    options=["(None)"] + all_table_names,
    value="(None)"
)
manage_cols_text = widgets.Text(
    description='Columns:',
    value='*'
)
manage_cond_text = widgets.Text(
    description='WHERE:',
    placeholder='e.g., age > 20'
)
manage_button = widgets.Button(description='Run Query')
manage_output = widgets.Output()

def on_manage_query_click(b):
    """Event handler for the Data Management 'Run Query' button."""
    with manage_output:
        clear_output(wait=True)
        
        # Get values from widgets
        table = manage_table_dropdown.value
        cols = manage_cols_text.value
        cond = manage_cond_text.value or None
        join = manage_join_dropdown.value
        
        if join == "(None)":
            join = None
            
        print(f"Querying: SELECT {cols} FROM {table}...")
        if join:
            print(f"Joining with: {join}")
        if cond:
            print(f"Condition: WHERE {cond}")
            
        try:
            # --- THIS IS WHERE WE USE YOUR SQL_Handler.py ---
            # (Make sure you've restarted the kernel after fixing the file)
            results = SQL_Handler.data_selection(table, cols, cond, join)
            # ------------------------------------------------
            
            if results:
                # Convert list of tuples to DataFrame for nice display
                df = pd.DataFrame(results)
                print(f"\nSuccess! Found {len(results)} rows.")
                display(df)
            else:
                print("\nQuery executed, but returned no results.")
                
        except Exception as e:
            print(f"\nAn error occurred: {e}")
            print("---")
            print("Make sure your SQL syntax is correct and columns exist.")

manage_button.on_click(on_manage_query_click)

# Layout for the first tab
data_management_tab = widgets.VBox([
    widgets.HBox([manage_table_dropdown, manage_join_dropdown]),
    widgets.HBox([manage_cols_text, manage_cond_text]),
    manage_button,
    manage_output
])


# ===============================================
# --- 4. CREATE "DATA GRAPHING" TAB WIDGETS ---
# (This section has the corrected .count() logic)
# ===============================================

# Widgets (renamed to avoid conflicts)
graph_table_dropdown = widgets.Dropdown(options=all_table_names, description="Table:")
graph_plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
graph_x_axis = widgets.Dropdown(description="X-Axis:")
graph_y_axis = widgets.Dropdown(description="Y-Axis:")
graph_best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
graph_generate_button = widgets.Button(description="Generate Plot", button_style="success")
graph_output = widgets.Output()

# Update dropdowns when table or plot type changes
def graph_update_columns(change):
    table = graph_table_dropdown.value
    ptype = graph_plot_type.value # Get the plot type
    
    conn = sqlite3.connect(DB_NAME)
    # Read only a few rows to quickly get column names and types
    try:
        df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    except Exception as e:
        with graph_output:
            print(f"Error reading table {table}: {e}")
            graph_x_axis.options = []
            graph_y_axis.options = []
        return
    finally:
        conn.close()

    if ptype == "Scatter":
        # SCATTER: Only show numeric columns
        numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
        graph_x_axis.options = numeric_cols
        graph_y_axis.options = numeric_cols
        graph_best_fit.layout.display = 'flex' # Show checkbox
    else:
        # BAR: Show all columns
        all_cols = df.columns.tolist()
        graph_x_axis.options = all_cols
        graph_y_axis.options = all_cols
        graph_best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if graph_x_axis.options:
        graph_x_axis.value = graph_x_axis.options[0]
    if graph_y_axis.options and len(graph_y_axis.options) > 1:
        # Try to set a different default for Y if possible
        graph_y_axis.value = graph_y_axis.options[1]
    elif graph_y_axis.options:
        graph_y_axis.value = graph_y_axis.options[0]

# Observers
graph_table_dropdown.observe(graph_update_columns, names="value")
graph_plot_type.observe(graph_update_columns, names="value")

# Plot when button is clicked
def graph_generate_plot(b):
    with graph_output:
        clear_output(wait=True)
        
        table = graph_table_dropdown.value
        x_col = graph_x_axis.value
        y_col = graph_y_axis.value
        kind = graph_plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        try:
            conn = sqlite3.connect(DB_NAME)
            df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
            conn.close()
        except Exception as e:
            print(f"Error querying database: {e}")
            return

        # --- Age Filter ---
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            original_count = len(df) # Get count before filtering
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            # We round 'age' *before* plotting
            if 'age' in df.columns:
                df['age'] = df['age'].round()
            print(f"Filter applied: Kept {len(df)} of {original_count} students aged 18–22.")

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            # Convert to numeric and drop NAs
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close() # Close the empty figure
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if graph_best_fit.value:
                if len(df) > 1:
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Try to convert columns to numeric where possible
            x_data = pd.to_numeric(df[x_col], errors='coerce')
            if x_data.isnull().all(): # Failed to convert (it's text)
                x_data = df[x_col]
                
            y_data = pd.to_numeric(df[y_col], errors='coerce')
            if y_data.isnull().all(): # Failed to convert (it't text)
                y_data = df[y_col]
                
            df[x_col] = x_data
            df[y_col] = y_data
            
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close() # Close the empty figure
                return
            
            # --- THIS IS THE SMARTER LOGIC ---
            
            def is_col_categorical(col_name, unique_thresh=25):
                """Helper function to check if a column should be treated as categorical."""
                if pd.api.types.is_object_dtype(df[col_name]):
                    return True
                if pd.api.types.is_numeric_dtype(df[col_name]):
                    if df[col_name].nunique() < unique_thresh:
                        if 'id' not in str(col_name).lower():
                            return True
                return False

            is_x_categorical = is_col_categorical(x_col)
            is_y_categorical = is_col_categorical(y_col)
            
            # Case 1: Y is categorical (e.g., X=age, Y=gender)
            if is_y_categorical:
                grouped = pd.crosstab(df[x_col], df[y_col]) 
                grouped = grouped.sort_index()
                
                indices = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        indices + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count") # This is a count

                labels = grouped.index.astype(str)
                positions = indices + bar_width*(len(grouped.columns)/2 - 0.5)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 2: Y is numeric (not categorical) and X is categorical
            # (e.g., X=age, Y=student_id)
            # This is now a "Count" chart
            elif is_x_categorical and not is_y_categorical:
                # --- THIS IS THE CORRECTED LINE ---
                agg_data = df.groupby(x_col)[y_col].count() 
                # ----------------------------------
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                
                # --- AND THIS IS THE CORRECTED Y-AXIS LABEL ---
                plt.ylabel("Number of Students") 
                # ----------------------------------------------

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 3: Both X and Y are numeric (and not categorical)
            else:
                agg_data = df.groupby(x_col)[y_col].count()
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Count of {y_col}") # This is a count

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

graph_generate_button.on_click(graph_generate_plot)

# Layout for the second tab
data_graphing_tab = widgets.VBox([
    widgets.HBox([graph_table_dropdown, graph_plot_type]),
    widgets.HBox([graph_x_axis, graph_y_axis]),
    widgets.HBox([graph_generate_button, graph_best_fit]),
    graph_output
])

# Trigger the update once to load the first table's columns
if all_table_names:
    graph_update_columns(None)


# ===============================================
# --- 5. ASSEMBLE AND DISPLAY THE TABS ---
# ===============================================

tab_container = widgets.Tab()
tab_container.children = [data_management_tab, data_graphing_tab]
tab_container.set_title(0, 'Data Management')
tab_container.set_title(1, 'Data Graphing')

display(tab_container)

Successfully imported SQL_Handler.py


Tab(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'P…

In [5]:
# --- 1. Import All Necessary Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import GridspecLayout
from IPython.display import display, clear_output
import io
import numpy as np

# --- Import your custom SQL handler ---
# This assumes SQL_Handler.py is in the same folder as your notebook
try:
    import SQL_Handler
    print("Successfully imported SQL_Handler.py")
except ImportError:
    print("ERROR: SQL_Handler.py not found. Make sure it's in the same directory.")
except Exception as e:
    print(f"Error importing SQL_Handler: {e}")

# --- 2. Shared Database Helper Functions ---
DB_NAME = 'Dataset.db'

def get_table_names():
    """Fetches all table names from the database."""
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0].strip() for table in cursor.fetchall() if table[0] != 'sqlite_sequence']
        conn.close()
        return tables
    except Exception as e:
        print(f"Database error: {e}")
        return []

def get_column_names(table_name):
    """Fetches all column names for a specific table."""
    if not table_name:
        return []
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute(f'PRAGMA table_info("{table_name}")')
        columns = [col[1].strip() for col in cursor.fetchall()]
        conn.close()
        return columns
    except Exception as e:
        print(f"Database error: {e}")
        return []

# Get table names once for all dropdowns
all_table_names = get_table_names()

# ===============================================
# --- 3. CREATE "DATA MANAGEMENT" TAB WIDGETS ---
# (This section is UPDATED)
# ===============================================

manage_table_dropdown = widgets.Dropdown(
    description='Table:',
    options=all_table_names
)
manage_join_dropdown = widgets.Dropdown(
    description='Join Table:',
    options=["(None)"] + all_table_names,
    value="(None)"
)

# --- NEW: Replaced Text box with a SelectMultiple list ---
manage_cols_select = widgets.SelectMultiple(
    description='Columns:',
    options=[],
    value=[],
    rows=6,
    disabled=True # Start disabled
)

manage_cond_text = widgets.Text(
    description='WHERE:',
    placeholder='e.g., age > 20'
)
manage_button = widgets.Button(description='Run Query')
manage_output = widgets.Output()

# --- NEW: Observer to update the column list ---
def on_manage_table_change(change):
    """Called when the management table dropdown changes."""
    table_name = change['new']
    if table_name:
        columns = get_column_names(table_name)
        manage_cols_select.options = columns
        manage_cols_select.value = columns # Select all by default
        manage_cols_select.disabled = False
    else:
        manage_cols_select.options = []
        manage_cols_select.value = []
        manage_cols_select.disabled = True

# --- MODIFIED: on_manage_query_click ---
def on_manage_query_click(b):
    """Event handler for the Data Management 'Run Query' button."""
    with manage_output:
        clear_output(wait=True)
        
        # Get values from widgets
        table = manage_table_dropdown.value
        
        # --- MODIFIED: Get columns from the new list ---
        cols_list = manage_cols_select.value
        if not cols_list:
            cols = '*' # Default to * if nothing is selected
        else:
            cols = ", ".join(cols_list) # "student_id, age, gender"
        # -----------------------------------------------
            
        cond = manage_cond_text.value or None
        join = manage_join_dropdown.value
        
        if join == "(None)":
            join = None
            
        print(f"Querying: SELECT {cols} FROM {table}...")
        if join:
            print(f"Joining with: {join}")
        if cond:
            print(f"Condition: WHERE {cond}")
            
        try:
            results = SQL_Handler.data_selection(table, cols, cond, join)
            
            if results:
                # --- MODIFIED: Get column names for the DataFrame header ---
                # If cols was '*', we need to fetch them.
                # If cols was "col1, col2", we can just use that.
                if cols == '*':
                    df_cols = get_column_names(table)
                else:
                    df_cols = list(cols_list)
                    
                # If we did a JOIN, we don't know the columns, so just use numbers
                if join:
                    df = pd.DataFrame(results)
                    print("Note: Column headers not available for JOIN queries.")
                else:
                    df = pd.DataFrame(results, columns=df_cols)
                # ------------------------------------------------------------
                
                print(f"\nSuccess! Found {len(results)} rows.")
                display(df)
            else:
                print("\nQuery executed, but returned no results.")
                
        except Exception as e:
            print(f"\nAn error occurred: {e}")
            print("---")
            print("Make sure your SQL syntax is correct and columns exist.")

# Wire up the new observer
manage_table_dropdown.observe(on_manage_table_change, names='value')
manage_button.on_click(on_manage_query_click)

# Layout for the first tab
# --- MODIFIED: Replaced manage_cols_text with manage_cols_select ---
data_management_tab = widgets.VBox([
    widgets.HBox([manage_table_dropdown, manage_join_dropdown]),
    widgets.HBox([manage_cols_select, manage_cond_text]),
    manage_button,
    manage_output
])

# --- NEW: Trigger the column update for the first table ---
if all_table_names:
    on_manage_table_change({'new': all_table_names[0]})


# ===============================================
# --- 4. CREATE "DATA GRAPHING" TAB WIDGETS ---
# (This section is unchanged)
# ===============================================

# Widgets (renamed to avoid conflicts)
graph_table_dropdown = widgets.Dropdown(options=all_table_names, description="Table:")
graph_plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
graph_x_axis = widgets.Dropdown(description="X-Axis:")
graph_y_axis = widgets.Dropdown(description="Y-Axis:")
graph_best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
graph_generate_button = widgets.Button(description="Generate Plot", button_style="success")
graph_output = widgets.Output()

# Update dropdowns when table or plot type changes
def graph_update_columns(change):
    table = graph_table_dropdown.value
    ptype = graph_plot_type.value # Get the plot type
    
    conn = sqlite3.connect(DB_NAME)
    # Read only a few rows to quickly get column names and types
    try:
        df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    except Exception as e:
        with graph_output:
            print(f"Error reading table {table}: {e}")
            graph_x_axis.options = []
            graph_y_axis.options = []
        return
    finally:
        conn.close()

    if ptype == "Scatter":
        # SCATTER: Only show numeric columns
        numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
        graph_x_axis.options = numeric_cols
        graph_y_axis.options = numeric_cols
        graph_best_fit.layout.display = 'flex' # Show checkbox
    else:
        # BAR: Show all columns
        all_cols = df.columns.tolist()
        graph_x_axis.options = all_cols
        graph_y_axis.options = all_cols
        graph_best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if graph_x_axis.options:
        graph_x_axis.value = graph_x_axis.options[0]
    if graph_y_axis.options and len(graph_y_axis.options) > 1:
        # Try to set a different default for Y if possible
        graph_y_axis.value = graph_y_axis.options[1]
    elif graph_y_axis.options:
        graph_y_axis.value = graph_y_axis.options[0]

# Observers
graph_table_dropdown.observe(graph_update_columns, names="value")
graph_plot_type.observe(graph_update_columns, names="value")

# Plot when button is clicked
def graph_generate_plot(b):
    with graph_output:
        clear_output(wait=True)
        
        table = graph_table_dropdown.value
        x_col = graph_x_axis.value
        y_col = graph_y_axis.value
        kind = graph_plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        try:
            conn = sqlite3.connect(DB_NAME)
            df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
            conn.close()
        except Exception as e:
            print(f"Error querying database: {e}")
            return

        # --- Age Filter ---
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            original_count = len(df) # Get count before filtering
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            # We round 'age' *before* plotting
            if 'age' in df.columns:
                df['age'] = df['age'].round()
            print(f"Filter applied: Kept {len(df)} of {original_count} students aged 18–22.")

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            # Convert to numeric and drop NAs
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close() # Close the empty figure
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if graph_best_fit.value:
                if len(df) > 1:
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Try to convert columns to numeric where possible
            x_data = pd.to_numeric(df[x_col], errors='coerce')
            if x_data.isnull().all(): # Failed to convert (it's text)
                x_data = df[x_col]
                
            y_data = pd.to_numeric(df[y_col], errors='coerce')
            if y_data.isnull().all(): # Failed to convert (it't text)
                y_data = df[y_col]
                
            df[x_col] = x_data
            df[y_col] = y_data
            
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close() # Close the empty figure
                return
            
            # --- THIS IS THE SMARTER LOGIC ---
            
            def is_col_categorical(col_name, unique_thresh=25):
                """Helper function to check if a column should be treated as categorical."""
                if pd.api.types.is_object_dtype(df[col_name]):
                    return True
                if pd.api.types.is_numeric_dtype(df[col_name]):
                    if df[col_name].nunique() < unique_thresh:
                        if 'id' not in str(col_name).lower():
                            return True
                return False

            is_x_categorical = is_col_categorical(x_col)
            is_y_categorical = is_col_categorical(y_col)
            
            # Case 1: Y is categorical (e.g., X=age, Y=gender)
            if is_y_categorical:
                grouped = pd.crosstab(df[x_col], df[y_col]) 
                grouped = grouped.sort_index()
                
                indices = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        indices + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count") # This is a count

                labels = grouped.index.astype(str)
                positions = indices + bar_width*(len(grouped.columns)/2 - 0.5)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 2: Y is numeric (not categorical) and X is categorical
            # (e.g., X=age, Y=student_id)
            # This is a "Count" chart
            elif is_x_categorical and not is_y_categorical:
                agg_data = df.groupby(x_col)[y_col].count() 
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel("Number of Students") 

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 3: Both X and Y are numeric (and not categorical)
            else:
                agg_data = df.groupby(x_col)[y_col].count()
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Count of {y_col}") # This is a count

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

graph_generate_button.on_click(graph_generate_plot)

# Layout for the second tab
data_graphing_tab = widgets.VBox([
    widgets.HBox([graph_table_dropdown, graph_plot_type]),
    widgets.HBox([graph_x_axis, graph_y_axis]),
    widgets.HBox([graph_generate_button, graph_best_fit]),
    graph_output
])

# Trigger the update once to load the first table's columns
if all_table_names:
    graph_update_columns(None)


# ===============================================
# --- 5. ASSEMBLE AND DISPLAY THE TABS ---
# ===============================================

tab_container = widgets.Tab()
tab_container.children = [data_management_tab, data_graphing_tab]
tab_container.set_title(0, 'Data Management')
tab_container.set_title(1, 'Data Graphing')

display(tab_container)

Successfully imported SQL_Handler.py


Tab(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'P…

In [6]:
# --- 1. Import All Necessary Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import GridspecLayout
from IPython.display import display, clear_output
import io
import numpy as np

# --- Import your custom SQL handler ---
# This assumes SQL_Handler.py is in the same folder as your notebook
try:
    import SQL_Handler
    print("Successfully imported SQL_Handler.py")
except ImportError:
    print("ERROR: SQL_Handler.py not found. Make sure it's in the same directory.")
except Exception as e:
    print(f"Error importing SQL_Handler: {e}")

# --- 2. Shared Database Helper Functions ---
DB_NAME = 'Dataset.db'

def get_table_names():
    """Fetches all table names from the database."""
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0].strip() for table in cursor.fetchall() if table[0] != 'sqlite_sequence']
        conn.close()
        return tables
    except Exception as e:
        print(f"Database error: {e}")
        return []

def get_column_names(table_name):
    """Fetches all column names for a specific table."""
    if not table_name:
        return []
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute(f'PRAGMA table_info("{table_name}")')
        columns = [col[1].strip() for col in cursor.fetchall()]
        conn.close()
        return columns
    except Exception as e:
        print(f"Database error: {e}")
        return []

# Get table names once for all dropdowns
all_table_names = get_table_names()

# ===============================================
# --- 3. CREATE "DATA MANAGEMENT" TAB WIDGETS ---
# (This section is HEAVILY UPDATED)
# ===============================================

manage_table_dropdown = widgets.Dropdown(
    description='Table:',
    options=all_table_names
)
manage_join_dropdown = widgets.Dropdown(
    description='Join Table:',
    options=["(None)"] + all_table_names,
    value="(None)"
)
manage_cols_select = widgets.SelectMultiple(
    description='Columns:',
    options=[],
    value=[],
    rows=6,
    disabled=True # Start disabled
)

# --- NEW: Replaced the "WHERE" text box with a 3-part filter ---
manage_filter_col_dropdown = widgets.Dropdown(
    options=[],
    description='Filter Column:',
    disabled=True
)
manage_filter_op_dropdown = widgets.Dropdown(
    options=['=', '!=', '>', '<', '>=', '<=', 'LIKE'],
    description='Operator:',
    value='='
)
manage_filter_val_text = widgets.Text(
    description='Value:',
    placeholder='e.g., 20 or "Male"'
)
# --- End of new widgets ---

manage_button = widgets.Button(description='Run Query')
manage_output = widgets.Output()

# --- UPDATED: This observer now updates *both* the columns list and the filter dropdown ---
def on_manage_table_change(change):
    """Called when the management table dropdown changes."""
    table_name = change['new']
    if table_name:
        columns = get_column_names(table_name)
        
        # Update column multi-select
        manage_cols_select.options = columns
        manage_cols_select.value = columns # Select all by default
        manage_cols_select.disabled = False
        
        # Update filter column dropdown
        manage_filter_col_dropdown.options = ["(None)"] + columns
        manage_filter_col_dropdown.value = "(None)"
        manage_filter_col_dropdown.disabled = False
    else:
        manage_cols_select.options = []
        manage_cols_select.value = []
        manage_cols_select.disabled = True
        
        manage_filter_col_dropdown.options = []
        manage_filter_col_dropdown.value = None
        manage_filter_col_dropdown.disabled = True

# --- MODIFIED: on_manage_query_click ---
def on_manage_query_click(b):
    """Event handler for the Data Management 'Run Query' button."""
    with manage_output:
        clear_output(wait=True)
        
        # Get values from widgets
        table = manage_table_dropdown.value
        
        # --- Get columns from the list ---
        cols_list = manage_cols_select.value
        if not cols_list:
            cols = '*' # Default to * if nothing is selected
        else:
            cols = ", ".join(f'"{c}"' for c in cols_list) # Add quotes for safety
            
        join = manage_join_dropdown.value
        if join == "(None)":
            join = None

        # --- NEW: Build the WHERE condition from the 3-part filter ---
        cond = None
        filter_col = manage_filter_col_dropdown.value
        
        if filter_col != "(None)":
            filter_op = manage_filter_op_dropdown.value
            filter_val = manage_filter_val_text.value
            
            # Smartly add quotes to the value if it's not a number
            # This is a simple check, but covers most cases
            if not filter_val.replace('.','',1).isdigit():
                filter_val = f"'{filter_val}'"
                
            cond = f'"{filter_col}" {filter_op} {filter_val}'
        # --- End of new condition logic ---
            
        print(f"Querying: SELECT {cols} FROM {table}...")
        if join:
            print(f"Joining with: {join}")
        if cond:
            print(f"Condition: WHERE {cond}")
            
        try:
            results = SQL_Handler.data_selection(table, cols, cond, join)
            
            if results:
                # If cols was '*', we need to fetch them.
                if cols == '*':
                    df_cols = get_column_names(table)
                else:
                    df_cols = list(cols_list)
                    
                # If we did a JOIN, we don't know the columns, so just use numbers
                if join:
                    df = pd.DataFrame(results)
                    print("Note: Column headers not available for JOIN queries.")
                else:
                    df = pd.DataFrame(results, columns=df_cols)
                
                print(f"\nSuccess! Found {len(results)} rows.")
                display(df)
            else:
                print("\nQuery executed, but returned no results.")
                
        except Exception as e:
            print(f"\nAn error occurred: {e}")
            print("---")
            print("Make sure your SQL syntax is correct and columns exist.")

# Wire up the new observer
manage_table_dropdown.observe(on_manage_table_change, names='value')
manage_button.on_click(on_manage_query_click)

# --- NEW: Create a VBox for the new 3-part filter ---
manage_filter_box = widgets.VBox([
    manage_filter_col_dropdown,
    manage_filter_op_dropdown,
    manage_filter_val_text
], layout={'border': '1px solid #CCC', 'padding': '10px', 'margin_top': '10px'})

# --- MODIFIED: The layout for the first tab ---
data_management_tab = widgets.VBox([
    widgets.HBox([manage_table_dropdown, manage_join_dropdown]),
    widgets.HBox([manage_cols_select, manage_filter_box]), # Swapped text box for filter box
    manage_button,
    manage_output
])
# --- End of layout modification ---

# --- NEW: Trigger the column update for the first table ---
if all_table_names:
    on_manage_table_change({'new': all_table_names[0]})


# ===============================================
# --- 4. CREATE "DATA GRAPHING" TAB WIDGETS ---
# (This section is unchanged)
# ===============================================

# Widgets (renamed to avoid conflicts)
graph_table_dropdown = widgets.Dropdown(options=all_table_names, description="Table:")
graph_plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
graph_x_axis = widgets.Dropdown(description="X-Axis:")
graph_y_axis = widgets.Dropdown(description="Y-Axis:")
graph_best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
graph_generate_button = widgets.Button(description="Generate Plot", button_style="success")
graph_output = widgets.Output()

# Update dropdowns when table or plot type changes
def graph_update_columns(change):
    table = graph_table_dropdown.value
    ptype = graph_plot_type.value # Get the plot type
    
    conn = sqlite3.connect(DB_NAME)
    # Read only a few rows to quickly get column names and types
    try:
        df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    except Exception as e:
        with graph_output:
            print(f"Error reading table {table}: {e}")
            graph_x_axis.options = []
            graph_y_axis.options = []
        return
    finally:
        conn.close()

    if ptype == "Scatter":
        # SCATTER: Only show numeric columns
        numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
        graph_x_axis.options = numeric_cols
        graph_y_axis.options = numeric_cols
        graph_best_fit.layout.display = 'flex' # Show checkbox
    else:
        # BAR: Show all columns
        all_cols = df.columns.tolist()
        graph_x_axis.options = all_cols
        graph_y_axis.options = all_cols
        graph_best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if graph_x_axis.options:
        graph_x_axis.value = graph_x_axis.options[0]
    if graph_y_axis.options and len(graph_y_axis.options) > 1:
        # Try to set a different default for Y if possible
        graph_y_axis.value = graph_y_axis.options[1]
    elif graph_y_axis.options:
        graph_y_axis.value = graph_y_axis.options[0]

# Observers
graph_table_dropdown.observe(graph_update_columns, names="value")
graph_plot_type.observe(graph_update_columns, names="value")

# Plot when button is clicked
def graph_generate_plot(b):
    with graph_output:
        clear_output(wait=True)
        
        table = graph_table_dropdown.value
        x_col = graph_x_axis.value
        y_col = graph_y_axis.value
        kind = graph_plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        try:
            conn = sqlite3.connect(DB_NAME)
            df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
            conn.close()
        except Exception as e:
            print(f"Error querying database: {e}")
            return

        # --- Age Filter ---
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            original_count = len(df) # Get count before filtering
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            # We round 'age' *before* plotting
            if 'age' in df.columns:
                df['age'] = df['age'].round()
            print(f"Filter applied: Kept {len(df)} of {original_count} students aged 18–22.")

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            # Convert to numeric and drop NAs
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close() # Close the empty figure
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if graph_best_fit.value:
                if len(df) > 1:
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Try to convert columns to numeric where possible
            x_data = pd.to_numeric(df[x_col], errors='coerce')
            if x_data.isnull().all(): # Failed to convert (it's text)
                x_data = df[x_col]
                
            y_data = pd.to_numeric(df[y_col], errors='coerce')
            if y_data.isnull().all(): # Failed to convert (it't text)
                y_data = df[y_col]
                
            df[x_col] = x_data
            df[y_col] = y_data
            
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close() # Close the empty figure
                return
            
            # --- THIS IS THE SMARTER LOGIC ---
            
            def is_col_categorical(col_name, unique_thresh=25):
                """Helper function to check if a column should be treated as categorical."""
                if pd.api.types.is_object_dtype(df[col_name]):
                    return True
                if pd.api.types.is_numeric_dtype(df[col_name]):
                    if df[col_name].nunique() < unique_thresh:
                        if 'id' not in str(col_name).lower():
                            return True
                return False

            is_x_categorical = is_col_categorical(x_col)
            is_y_categorical = is_col_categorical(y_col)
            
            # Case 1: Y is categorical (e.g., X=age, Y=gender)
            if is_y_categorical:
                grouped = pd.crosstab(df[x_col], df[y_col]) 
                grouped = grouped.sort_index()
                
                indices = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        indices + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count") # This is a count

                labels = grouped.index.astype(str)
                positions = indices + bar_width*(len(grouped.columns)/2 - 0.5)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 2: Y is numeric (not categorical) and X is categorical
            # (e.g., X=age, Y=student_id)
            # This is a "Count" chart
            elif is_x_categorical and not is_y_categorical:
                agg_data = df.groupby(x_col)[y_col].count() 
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel("Number of Students") 

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 3: Both X and Y are numeric (and not categorical)
            else:
                agg_data = df.groupby(x_col)[y_col].count()
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Count of {y_col}") # This is a count

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

graph_generate_button.on_click(graph_generate_plot)

# Layout for the second tab
data_graphing_tab = widgets.VBox([
    widgets.HBox([graph_table_dropdown, graph_plot_type]),
    widgets.HBox([graph_x_axis, graph_y_axis]),
    widgets.HBox([graph_generate_button, graph_best_fit]),
    graph_output
])

# Trigger the update once to load the first table's columns
if all_table_names:
    graph_update_columns(None)


# ===============================================
# --- 5. ASSEMBLE AND DISPLAY THE TABS ---
# ===============================================

tab_container = widgets.Tab()
tab_container.children = [data_management_tab, data_graphing_tab]
tab_container.set_title(0, 'Data Management')
tab_container.set_title(1, 'Data Graphing')

display(tab_container)

Successfully imported SQL_Handler.py


Tab(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'P…

In [7]:
# --- 1. Import All Necessary Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import GridspecLayout
from IPython.display import display, clear_output
import io
import numpy as np

# --- Import your custom SQL handler ---
# This assumes SQL_Handler.py is in the same folder as your notebook
try:
    import SQL_Handler
    print("Successfully imported SQL_Handler.py")
except ImportError:
    print("ERROR: SQL_Handler.py not found. Make sure it's in the same directory.")
except Exception as e:
    print(f"Error importing SQL_Handler: {e}")

# --- 2. Shared Database Helper Functions ---
DB_NAME = 'Dataset.db'

def get_table_names():
    """Fetches all table names from the database."""
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0].strip() for table in cursor.fetchall() if table[0] != 'sqlite_sequence']
        conn.close()
        return tables
    except Exception as e:
        print(f"Database error: {e}")
        return []

def get_column_names(table_name):
    """Fetches all column names for a specific table."""
    if not table_name:
        return []
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute(f'PRAGMA table_info("{table_name}")')
        columns = [col[1].strip() for col in cursor.fetchall()]
        conn.close()
        return columns
    except Exception as e:
        print(f"Database error: {e}")
        return []

# Get table names once for all dropdowns
all_table_names = get_table_names()

# ===============================================
# --- 3. CREATE "DATA MANAGEMENT" TAB WIDGETS ---
# (This section is UPDATED)
# ===============================================

manage_table_dropdown = widgets.Dropdown(
    description='Table:',
    options=all_table_names
)
manage_join_dropdown = widgets.Dropdown(
    description='Join Table:',
    options=["(None)"] + all_table_names,
    value="(None)"
)

# --- NEW: Replaced SelectMultiple with a simple Dropdown ---
manage_cols_dropdown = widgets.Dropdown(
    description='Column:',
    options=[],
    value=None,
    disabled=True # Start disabled
)

# --- Filter widgets are unchanged ---
manage_filter_col_dropdown = widgets.Dropdown(
    options=[],
    description='Filter Column:',
    disabled=True
)
manage_filter_op_dropdown = widgets.Dropdown(
    options=['=', '!=', '>', '<', '>=', '<=', 'LIKE'],
    description='Operator:',
    value='='
)
manage_filter_val_text = widgets.Text(
    description='Value:',
    placeholder='e.g., 20 or "Male"'
)
# --- End of filter widgets ---

manage_button = widgets.Button(description='Run Query')
manage_output = widgets.Output()

# --- UPDATED: This observer updates all 3 dependent dropdowns ---
def on_manage_table_change(change):
    """Called when the management table dropdown changes."""
    table_name = change['new']
    if table_name:
        columns = get_column_names(table_name)
        
        # --- MODIFIED: Update the new Column Dropdown ---
        manage_cols_dropdown.options = ["*"] + columns # Add '*' as first option
        manage_cols_dropdown.value = "*" # Select '*' by default
        manage_cols_dropdown.disabled = False
        
        # Update filter column dropdown
        manage_filter_col_dropdown.options = ["(None)"] + columns
        manage_filter_col_dropdown.value = "(None)"
        manage_filter_col_dropdown.disabled = False
    else:
        manage_cols_dropdown.options = []
        manage_cols_dropdown.value = None
        manage_cols_dropdown.disabled = True
        
        manage_filter_col_dropdown.options = []
        manage_filter_col_dropdown.value = None
        manage_filter_col_dropdown.disabled = True

# --- MODIFIED: on_manage_query_click ---
def on_manage_query_click(b):
    """Event handler for the Data Management 'Run Query' button."""
    with manage_output:
        clear_output(wait=True)
        
        # Get values from widgets
        table = manage_table_dropdown.value
        
        # --- MODIFIED: Get column from the new dropdown ---
        cols = manage_cols_dropdown.value
        if not cols:
            cols = '*' # Failsafe
            
        join = manage_join_dropdown.value
        if join == "(None)":
            join = None

        # --- Build the WHERE condition ---
        cond = None
        filter_col = manage_filter_col_dropdown.value
        
        if filter_col != "(None)":
            filter_op = manage_filter_op_dropdown.value
            filter_val = manage_filter_val_text.value
            
            if not filter_val.replace('.','',1).isdigit():
                filter_val = f"'{filter_val}'"
                
            cond = f'"{filter_col}" {filter_op} {filter_val}'
        # --- End of condition logic ---
            
        print(f"Querying: SELECT {cols} FROM {table}...")
        if join:
            print(f"Joining with: {join}")
        if cond:
            print(f"Condition: WHERE {cond}")
            
        try:
            results = SQL_Handler.data_selection(table, cols, cond, join)
            
            if results:
                # --- Get column names for the DataFrame header ---
                if cols == '*':
                    df_cols = get_column_names(table)
                else:
                    df_cols = [cols] # It's just the one column
                    
                # If we did a JOIN, we don't know the columns
                if join:
                    df = pd.DataFrame(results)
                    print("Note: Column headers not available for JOIN queries.")
                else:
                    df = pd.DataFrame(results, columns=df_cols)
                
                print(f"\nSuccess! Found {len(results)} rows.")
                display(df)
            else:
                print("\nQuery executed, but returned no results.")
                
        except Exception as e:
            print(f"\nAn error occurred: {e}")
            print("---")
            print("Make sure your SQL syntax is correct and columns exist.")

# Wire up the observer
manage_table_dropdown.observe(on_manage_table_change, names='value')
manage_button.on_click(on_manage_query_click)

# --- Create a VBox for the 3-part filter ---
manage_filter_box = widgets.VBox([
    manage_filter_col_dropdown,
    manage_filter_op_dropdown,
    manage_filter_val_text
], layout={'border': '1px solid #CCC', 'padding': '10px'})

# --- MODIFIED: The layout for the first tab ---
data_management_tab = widgets.VBox([
    widgets.HBox([manage_table_dropdown, manage_join_dropdown]),
    # --- Replaced manage_cols_select with manage_cols_dropdown ---
    widgets.HBox([manage_cols_dropdown, manage_filter_box]), 
    manage_button,
    manage_output
])
# --- End of layout modification ---

# --- Trigger the column update for the first table ---
if all_table_names:
    on_manage_table_change({'new': all_table_names[0]})


# ===============================================
# --- 4. CREATE "DATA GRAPHING" TAB WIDGETS ---
# (This section is unchanged)
# ===============================================

# Widgets (renamed to avoid conflicts)
graph_table_dropdown = widgets.Dropdown(options=all_table_names, description="Table:")
graph_plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
graph_x_axis = widgets.Dropdown(description="X-Axis:")
graph_y_axis = widgets.Dropdown(description="Y-Axis:")
graph_best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
graph_generate_button = widgets.Button(description="Generate Plot", button_style="success")
graph_output = widgets.Output()

# Update dropdowns when table or plot type changes
def graph_update_columns(change):
    table = graph_table_dropdown.value
    ptype = graph_plot_type.value # Get the plot type
    
    conn = sqlite3.connect(DB_NAME)
    # Read only a few rows to quickly get column names and types
    try:
        df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    except Exception as e:
        with graph_output:
            print(f"Error reading table {table}: {e}")
            graph_x_axis.options = []
            graph_y_axis.options = []
        return
    finally:
        conn.close()

    if ptype == "Scatter":
        # SCATTER: Only show numeric columns
        numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
        graph_x_axis.options = numeric_cols
        graph_y_axis.options = numeric_cols
        graph_best_fit.layout.display = 'flex' # Show checkbox
    else:
        # BAR: Show all columns
        all_cols = df.columns.tolist()
        graph_x_axis.options = all_cols
        graph_y_axis.options = all_cols
        graph_best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if graph_x_axis.options:
        graph_x_axis.value = graph_x_axis.options[0]
    if graph_y_axis.options and len(graph_y_axis.options) > 1:
        # Try to set a different default for Y if possible
        graph_y_axis.value = graph_y_axis.options[1]
    elif graph_y_axis.options:
        graph_y_axis.value = graph_y_axis.options[0]

# Observers
graph_table_dropdown.observe(graph_update_columns, names="value")
graph_plot_type.observe(graph_udate_columns, names="value")

# Plot when button is clicked
def graph_generate_plot(b):
    with graph_output:
        clear_output(wait=True)
        
        table = graph_table_dropdown.value
        x_col = graph_x_axis.value
        y_col = graph_y_axis.value
        kind = graph_plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        try:
            conn = sqlite3.connect(DB_NAME)
            df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
            conn.close()
        except Exception as e:
            print(f"Error querying database: {e}")
            return

        # --- Age Filter ---
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            original_count = len(df) # Get count before filtering
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            # We round 'age' *before* plotting
            if 'age' in df.columns:
                df['age'] = df['age'].round()
            print(f"Filter applied: Kept {len(df)} of {original_count} students aged 18–22.")

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            # Convert to numeric and drop NAs
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close() # Close the empty figure
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if graph_best_fit.value:
                if len(df) > 1:
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Try to convert columns to numeric where possible
            x_data = pd.to_numeric(df[x_col], errors='coerce')
            if x_data.isnull().all(): # Failed to convert (it's text)
                x_data = df[x_col]
                
            y_data = pd.to_numeric(df[y_col], errors='coerce')
            if y_data.isnull().all(): # Failed to convert (it't text)
                y_data = df[y_col]
                
            df[x_col] = x_data
            df[y_col] = y_data
            
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close() # Close the empty figure
                return
            
            # --- THIS IS THE SMARTER LOGIC ---
            
            def is_col_categorical(col

_IncompleteInputError: incomplete input (1272287678.py, line 348)

In [8]:
# --- 1. Import All Necessary Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import GridspecLayout
from IPython.display import display, clear_output
import io
import numpy as np

# --- Import your custom SQL handler ---
# This assumes SQL_Handler.py is in the same folder as your notebook
try:
    import SQL_Handler
    print("Successfully imported SQL_Handler.py")
except ImportError:
    print("ERROR: SQL_Handler.py not found. Make sure it's in the same directory.")
except Exception as e:
    print(f"Error importing SQL_Handler: {e}")

# --- 2. Shared Database Helper Functions ---
DB_NAME = 'Dataset.db'

def get_table_names():
    """Fetches all table names from the database."""
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0].strip() for table in cursor.fetchall() if table[0] != 'sqlite_sequence']
        conn.close()
        return tables
    except Exception as e:
        print(f"Database error: {e}")
        return []

def get_column_names(table_name):
    """Fetches all column names for a specific table."""
    if not table_name:
        return []
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute(f'PRAGMA table_info("{table_name}")')
        columns = [col[1].strip() for col in cursor.fetchall()]
        conn.close()
        return columns
    except Exception as e:
        print(f"Database error: {e}")
        return []

# Get table names once for all dropdowns
all_table_names = get_table_names()

# ===============================================
# --- 3. CREATE "DATA MANAGEMENT" TAB WIDGETS ---
# (This section is UPDATED)
# ===============================================

manage_table_dropdown = widgets.Dropdown(
    description='Table:',
    options=all_table_names
)
manage_join_dropdown = widgets.Dropdown(
    description='Join Table:',
    options=["(None)"] + all_table_names,
    value="(None)"
)

# --- NEW: Replaced SelectMultiple with a simple Dropdown ---
manage_cols_dropdown = widgets.Dropdown(
    description='Column:',
    options=[],
    value=None,
    disabled=True # Start disabled
)

# --- Filter widgets are unchanged ---
manage_filter_col_dropdown = widgets.Dropdown(
    options=[],
    description='Filter Column:',
    disabled=True
)
manage_filter_op_dropdown = widgets.Dropdown(
    options=['=', '!=', '>', '<', '>=', '<=', 'LIKE'],
    description='Operator:',
    value='='
)
manage_filter_val_text = widgets.Text(
    description='Value:',
    placeholder='e.g., 20 or "Male"'
)
# --- End of filter widgets ---

manage_button = widgets.Button(description='Run Query')
manage_output = widgets.Output()

# --- UPDATED: This observer updates all 3 dependent dropdowns ---
def on_manage_table_change(change):
    """Called when the management table dropdown changes."""
    table_name = change['new']
    if table_name:
        columns = get_column_names(table_name)
        
        # --- MODIFIED: Update the new Column Dropdown ---
        manage_cols_dropdown.options = ["*"] + columns # Add '*' as first option
        manage_cols_dropdown.value = "*" # Select '*' by default
        manage_cols_dropdown.disabled = False
        
        # Update filter column dropdown
        manage_filter_col_dropdown.options = ["(None)"] + columns
        manage_filter_col_dropdown.value = "(None)"
        manage_filter_col_dropdown.disabled = False
    else:
        manage_cols_dropdown.options = []
        manage_cols_dropdown.value = None
        manage_cols_dropdown.disabled = True
        
        manage_filter_col_dropdown.options = []
        manage_filter_col_dropdown.value = None
        manage_filter_col_dropdown.disabled = True

# --- MODIFIED: on_manage_query_click ---
def on_manage_query_click(b):
    """Event handler for the Data Management 'Run Query' button."""
    with manage_output:
        clear_output(wait=True)
        
        # Get values from widgets
        table = manage_table_dropdown.value
        
        # --- MODIFIED: Get column from the new dropdown ---
        cols = manage_cols_dropdown.value
        if not cols:
            cols = '*' # Failsafe
            
        join = manage_join_dropdown.value
        if join == "(None)":
            join = None

        # --- Build the WHERE condition ---
        cond = None
        filter_col = manage_filter_col_dropdown.value
        
        if filter_col != "(None)":
            filter_op = manage_filter_op_dropdown.value
            filter_val = manage_filter_val_text.value
            
            if not filter_val.replace('.','',1).isdigit():
                filter_val = f"'{filter_val}'"
                
            cond = f'"{filter_col}" {filter_op} {filter_val}'
        # --- End of condition logic ---
            
        print(f"Querying: SELECT {cols} FROM {table}...")
        if join:
            print(f"Joining with: {join}")
        if cond:
            print(f"Condition: WHERE {cond}")
            
        try:
            results = SQL_Handler.data_selection(table, cols, cond, join)
            
            if results:
                # --- Get column names for the DataFrame header ---
                if cols == '*':
                    df_cols = get_column_names(table)
                else:
                    df_cols = [cols] # It's just the one column
                    
                # If we did a JOIN, we don't know the columns
                if join:
                    df = pd.DataFrame(results)
                    print("Note: Column headers not available for JOIN queries.")
                else:
                    df = pd.DataFrame(results, columns=df_cols)
                
                print(f"\nSuccess! Found {len(results)} rows.")
                display(df)
            else:
                print("\nQuery executed, but returned no results.")
                
        except Exception as e:
            print(f"\nAn error occurred: {e}")
            print("---")
            print("Make sure your SQL syntax is correct and columns exist.")

# Wire up the observer
manage_table_dropdown.observe(on_manage_table_change, names='value')
manage_button.on_click(on_manage_query_click)

# --- Create a VBox for the 3-part filter ---
manage_filter_box = widgets.VBox([
    manage_filter_col_dropdown,
    manage_filter_op_dropdown,
    manage_filter_val_text
], layout={'border': '1px solid #CCC', 'padding': '10px'})

# --- MODIFIED: The layout for the first tab ---
data_management_tab = widgets.VBox([
    widgets.HBox([manage_table_dropdown, manage_join_dropdown]),
    # --- Replaced manage_cols_select with manage_cols_dropdown ---
    widgets.HBox([manage_cols_dropdown, manage_filter_box]), 
    manage_button,
    manage_output
])
# --- End of layout modification ---

# --- Trigger the column update for the first table ---
if all_table_names:
    on_manage_table_change({'new': all_table_names[0]})


# ===============================================
# --- 4. CREATE "DATA GRAPHING" TAB WIDGETS ---
# (This section is unchanged)
# ===============================================

# Widgets (renamed to avoid conflicts)
graph_table_dropdown = widgets.Dropdown(options=all_table_names, description="Table:")
graph_plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
graph_x_axis = widgets.Dropdown(description="X-Axis:")
graph_y_axis = widgets.Dropdown(description="Y-Axis:")
graph_best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
graph_generate_button = widgets.Button(description="Generate Plot", button_style="success")
graph_output = widgets.Output()

# Update dropdowns when table or plot type changes
def graph_update_columns(change):
    table = graph_table_dropdown.value
    ptype = graph_plot_type.value # Get the plot type
    
    conn = sqlite3.connect(DB_NAME)
    # Read only a few rows to quickly get column names and types
    try:
        df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    except Exception as e:
        with graph_output:
            print(f"Error reading table {table}: {e}")
            graph_x_axis.options = []
            graph_y_axis.options = []
        return
    finally:
        conn.close()

    if ptype == "Scatter":
        # SCATTER: Only show numeric columns
        numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
        graph_x_axis.options = numeric_cols
        graph_y_axis.options = numeric_cols
        graph_best_fit.layout.display = 'flex' # Show checkbox
    else:
        # BAR: Show all columns
        all_cols = df.columns.tolist()
        graph_x_axis.options = all_cols
        graph_y_axis.options = all_cols
        graph_best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if graph_x_axis.options:
        graph_x_axis.value = graph_x_axis.options[0]
    if graph_y_axis.options and len(graph_y_axis.options) > 1:
        # Try to set a different default for Y if possible
        graph_y_axis.value = graph_y_axis.options[1]
    elif graph_y_axis.options:
        graph_y_axis.value = graph_y_axis.options[0]

# Observers
graph_table_dropdown.observe(graph_update_columns, names="value")
graph_plot_type.observe(graph_udate_columns, names="value")

# Plot when button is clicked
def graph_generate_plot(b):
    with graph_output:
        clear_output(wait=True)
        
        table = graph_table_dropdown.value
        x_col = graph_x_axis.value
        y_col = graph_y_axis.value
        kind = graph_plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        try:
            conn = sqlite3.connect(DB_NAME)
            df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
            conn.close()
        except Exception as e:
            print(f"Error querying database: {e}")
            return

        # --- Age Filter ---
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            original_count = len(df) # Get count before filtering
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            # We round 'age' *before* plotting
            if 'age' in df.columns:
                df['age'] = df['age'].round()
            print(f"Filter applied: Kept {len(df)} of {original_count} students aged 18–22.")

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            # Convert to numeric and drop NAs
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close() # Close the empty figure
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if graph_best_fit.value:
                if len(df) > 1:
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Try to convert columns to numeric where possible
            x_data = pd.to_numeric(df[x_col], errors='coerce')
            if x_data.isnull().all(): # Failed to convert (it's text)
                x_data = df[x_col]
                
            y_data = pd.to_numeric(df[y_col], errors='coerce')
            if y_data.isnull().all(): # Failed to convert (it't text)
                y_data = df[y_col]
                
            df[x_col] = x_data
            df[y_col] = y_data
            
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close() # Close the empty figure
                return
            
            # --- THIS IS THE SMARTER LOGIC ---
            
            def is_col_categorical(col_name, unique_thresh=25):
                """Helper function to check if a column should be treated as categorical."""
                if pd.api.types.is_object_dtype(df[col_name]):
                    return True
                if pd.api.types.is_numeric_dtype(df[col_name]):
                    if df[col_name].nunique() < unique_thresh:
                        if 'id' not in str(col_name).lower():
                            return True
                return False

            is_x_categorical = is_col_categorical(x_col)
            is_y_categorical = is_col_categorical(y_col)
            
            # Case 1: Y is categorical (e.g., X=age, Y=gender)
            if is_y_categorical:
                grouped = pd.crosstab(df[x_col], df[y_col]) 
                grouped = grouped.sort_index()
                
                indices = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        indices + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count") # This is a count

                labels = grouped.index.astype(str)
                positions = indices + bar_width*(len(grouped.columns)/2 - 0.5)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 2: Y is numeric (not categorical) and X is categorical
            # (e.g., X=age, Y=student_id)
            # This is a "Count" chart
            elif is_x_categorical and not is_y_categorical:
                agg_data = df.groupby(x_col)[y_col].count() 
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel("Number of Students") 

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 3: Both X and Y are numeric (and not categorical)
            else:
                agg_data = df.groupby(x_col)[y_col].count()
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Count of {y_col}") # This is a count

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

graph_generate_button.on_click(graph_generate_plot)

# Layout for the second tab
data_graphing_tab = widgets.VBox([
    widgets.HBox([graph_table_dropdown, graph_plot_type]),
    widgets.HBox([graph_x_axis, graph_y_axis]),
    widgets.HBox([graph_generate_button, graph_best_fit]),
    graph_output
])

# Trigger the update once to load the first table's columns
if all_table_names:
    graph_update_columns(None)


# ===============================================
# --- 5. ASSEMBLE AND DISPLAY THE TABS ---
# ===============================================

tab_container = widgets.Tab()
tab_container.children = [data_management_tab, data_graphing_tab]
tab_container.set_title(0, 'Data Management')
tab_container.set_title(1, 'Data Graphing')

display(tab_container)

Successfully imported SQL_Handler.py


NameError: name 'graph_udate_columns' is not defined

In [9]:
# --- 1. Import All Necessary Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import GridspecLayout
from IPython.display import display, clear_output
import io
import numpy as np

# --- Import your custom SQL handler ---
# This assumes SQL_Handler.py is in the same folder as your notebook
try:
    import SQL_Handler
    print("Successfully imported SQL_Handler.py")
except ImportError:
    print("ERROR: SQL_Handler.py not found. Make sure it's in the same directory.")
except Exception as e:
    print(f"Error importing SQL_Handler: {e}")

# --- 2. Shared Database Helper Functions ---
DB_NAME = 'Dataset.db'

def get_table_names():
    """Fetches all table names from the database."""
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0].strip() for table in cursor.fetchall() if table[0] != 'sqlite_sequence']
        conn.close()
        return tables
    except Exception as e:
        print(f"Database error: {e}")
        return []

def get_column_names(table_name):
    """Fetches all column names for a specific table."""
    if not table_name:
        return []
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute(f'PRAGMA table_info("{table_name}")')
        columns = [col[1].strip() for col in cursor.fetchall()]
        conn.close()
        return columns
    except Exception as e:
        print(f"Database error: {e}")
        return []

# Get table names once for all dropdowns
all_table_names = get_table_names()

# ===============================================
# --- 3. CREATE "DATA MANAGEMENT" TAB WIDGETS ---
# (This section is UPDATED)
# ===============================================

manage_table_dropdown = widgets.Dropdown(
    description='Table:',
    options=all_table_names
)
manage_join_dropdown = widgets.Dropdown(
    description='Join Table:',
    options=["(None)"] + all_table_names,
    value="(None)"
)

# --- NEW: Replaced SelectMultiple with a simple Dropdown ---
manage_cols_dropdown = widgets.Dropdown(
    description='Column:',
    options=[],
    value=None,
    disabled=True # Start disabled
)

# --- Filter widgets are unchanged ---
manage_filter_col_dropdown = widgets.Dropdown(
    options=[],
    description='Filter Column:',
    disabled=True
)
manage_filter_op_dropdown = widgets.Dropdown(
    options=['=', '!=', '>', '<', '>=', '<=', 'LIKE'],
    description='Operator:',
    value='='
)
manage_filter_val_text = widgets.Text(
    description='Value:',
    placeholder='e.g., 20 or "Male"'
)
# --- End of filter widgets ---

manage_button = widgets.Button(description='Run Query')
manage_output = widgets.Output()

# --- UPDATED: This observer updates all 3 dependent dropdowns ---
def on_manage_table_change(change):
    """Called when the management table dropdown changes."""
    table_name = change['new']
    if table_name:
        columns = get_column_names(table_name)
        
        # --- MODIFIED: Update the new Column Dropdown ---
        manage_cols_dropdown.options = ["*"] + columns # Add '*' as first option
        manage_cols_dropdown.value = "*" # Select '*' by default
        manage_cols_dropdown.disabled = False
        
        # Update filter column dropdown
        manage_filter_col_dropdown.options = ["(None)"] + columns
        manage_filter_col_dropdown.value = "(None)"
        manage_filter_col_dropdown.disabled = False
    else:
        manage_cols_dropdown.options = []
        manage_cols_dropdown.value = None
        manage_cols_dropdown.disabled = True
        
        manage_filter_col_dropdown.options = []
        manage_filter_col_dropdown.value = None
        manage_filter_col_dropdown.disabled = True

# --- MODIFIED: on_manage_query_click ---
def on_manage_query_click(b):
    """Event handler for the Data Management 'Run Query' button."""
    with manage_output:
        clear_output(wait=True)
        
        # Get values from widgets
        table = manage_table_dropdown.value
        
        # --- MODIFIED: Get column from the new dropdown ---
        cols = manage_cols_dropdown.value
        if not cols:
            cols = '*' # Failsafe
            
        join = manage_join_dropdown.value
        if join == "(None)":
            join = None

        # --- Build the WHERE condition ---
        cond = None
        filter_col = manage_filter_col_dropdown.value
        
        if filter_col != "(None)":
            filter_op = manage_filter_op_dropdown.value
            filter_val = manage_filter_val_text.value
            
            if not filter_val.replace('.','',1).isdigit():
                filter_val = f"'{filter_val}'"
                
            cond = f'"{filter_col}" {filter_op} {filter_val}'
        # --- End of condition logic ---
            
        print(f"Querying: SELECT {cols} FROM {table}...")
        if join:
            print(f"Joining with: {join}")
        if cond:
            print(f"Condition: WHERE {cond}")
            
        try:
            results = SQL_Handler.data_selection(table, cols, cond, join)
            
            if results:
                # --- Get column names for the DataFrame header ---
                if cols == '*':
                    df_cols = get_column_names(table)
                else:
                    df_cols = [cols] # It's just the one column
                    
                # If we did a JOIN, we don't know the columns
                if join:
                    df = pd.DataFrame(results)
                    print("Note: Column headers not available for JOIN queries.")
                else:
                    df = pd.DataFrame(results, columns=df_cols)
                
                print(f"\nSuccess! Found {len(results)} rows.")
                display(df)
            else:
                print("\nQuery executed, but returned no results.")
                
        except Exception as e:
            print(f"\nAn error occurred: {e}")
            print("---")
            print("Make sure your SQL syntax is correct and columns exist.")

# Wire up the observer
manage_table_dropdown.observe(on_manage_table_change, names='value')
manage_button.on_click(on_manage_query_click)

# --- Create a VBox for the 3-part filter ---
manage_filter_box = widgets.VBox([
    manage_filter_col_dropdown,
    manage_filter_op_dropdown,
    manage_filter_val_text
], layout={'border': '1px solid #CCC', 'padding': '10px'})

# --- MODIFIED: The layout for the first tab ---
data_management_tab = widgets.VBox([
    widgets.HBox([manage_table_dropdown, manage_join_dropdown]),
    # --- Replaced manage_cols_select with manage_cols_dropdown ---
    widgets.HBox([manage_cols_dropdown, manage_filter_box]), 
    manage_button,
    manage_output
])
# --- End of layout modification ---

# --- Trigger the column update for the first table ---
if all_table_names:
    on_manage_table_change({'new': all_table_names[0]})


# ===============================================
# --- 4. CREATE "DATA GRAPHING" TAB WIDGETS ---
# (This section is unchanged)
# ===============================================

# Widgets (renamed to avoid conflicts)
graph_table_dropdown = widgets.Dropdown(options=all_table_names, description="Table:")
graph_plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
graph_x_axis = widgets.Dropdown(description="X-Axis:")
graph_y_axis = widgets.Dropdown(description="Y-Axis:")
graph_best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
graph_generate_button = widgets.Button(description="Generate Plot", button_style="success")
graph_output = widgets.Output()

# Update dropdowns when table or plot type changes
def graph_update_columns(change):
    table = graph_table_dropdown.value
    ptype = graph_plot_type.value # Get the plot type
    
    conn = sqlite3.connect(DB_NAME)
    # Read only a few rows to quickly get column names and types
    try:
        df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    except Exception as e:
        with graph_output:
            print(f"Error reading table {table}: {e}")
            graph_x_axis.options = []
            graph_y_axis.options = []
        return
    finally:
        conn.close()

    if ptype == "Scatter":
        # SCATTER: Only show numeric columns
        numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
        graph_x_axis.options = numeric_cols
        graph_y_axis.options = numeric_cols
        graph_best_fit.layout.display = 'flex' # Show checkbox
    else:
        # BAR: Show all columns
        all_cols = df.columns.tolist()
        graph_x_axis.options = all_cols
        graph_y_axis.options = all_cols
        graph_best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if graph_x_axis.options:
        graph_x_axis.value = graph_x_axis.options[0]
    if graph_y_axis.options and len(graph_y_axis.options) > 1:
        # Try to set a different default for Y if possible
        graph_y_axis.value = graph_y_axis.options[1]
    elif graph_y_axis.options:
        graph_y_axis.value = graph_y_axis.options[0]

# Observers
graph_table_dropdown.observe(graph_update_columns, names="value")
graph_plot_type.observe(graph_udate_columns, names="value")

# Plot when button is clicked
def graph_generate_plot(b):
    with graph_output:
        clear_output(wait=True)
        
        table = graph_table_dropdown.value
        x_col = graph_x_axis.value
        y_col = graph_y_axis.value
        kind = graph_plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        try:
            conn = sqlite3.connect(DB_NAME)
            df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
            conn.close()
        except Exception as e:
            print(f"Error querying database: {e}")
            return

        # --- Age Filter ---
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            original_count = len(df) # Get count before filtering
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            # We round 'age' *before* plotting
            if 'age' in df.columns:
                df['age'] = df['age'].round()
            print(f"Filter applied: Kept {len(df)} of {original_count} students aged 18–22.")

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            # Convert to numeric and drop NAs
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close() # Close the empty figure
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if graph_best_fit.value:
                if len(df) > 1:
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Try to convert columns to numeric where possible
            x_data = pd.to_numeric(df[x_col], errors='coerce')
            if x_data.isnull().all(): # Failed to convert (it's text)
                x_data = df[x_col]
                
            y_data = pd.to_numeric(df[y_col], errors='coerce')
            if y_data.isnull().all(): # Failed to convert (it't text)
                y_data = df[y_col]
                
            df[x_col] = x_data
            df[y_col] = y_data
            
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close() # Close the empty figure
                return
            
            # --- THIS IS THE SMARTER LOGIC ---
            
            def is_col_categorical(col_name, unique_thresh=25):
                """Helper function to check if a column should be treated as categorical."""
                if pd.api.types.is_object_dtype(df[col_name]):
                    return True
                if pd.api.types.is_numeric_dtype(df[col_name]):
                    if df[col_name].nunique() < unique_thresh:
                        if 'id' not in str(col_name).lower():
                            return True
                return False

            is_x_categorical = is_col_categorical(x_col)
            is_y_categorical = is_col_categorical(y_col)
            
            # Case 1: Y is categorical (e.g., X=age, Y=gender)
            if is_y_categorical:
                grouped = pd.crosstab(df[x_col], df[y_col]) 
                grouped = grouped.sort_index()
                
                indices = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        indices + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count") # This is a count

                labels = grouped.index.astype(str)
                positions = indices + bar_width*(len(grouped.columns)/2 - 0.5)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 2: Y is numeric (not categorical) and X is categorical
            # (e.g., X=age, Y=student_id)
            # This is a "Count" chart
            elif is_x_categorical and not is_y_categorical:
                agg_data = df.groupby(x_col)[y_col].count() 
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel("Number of Students") 

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 3: Both X and Y are numeric (and not categorical)
            else:
                agg_data = df.groupby(x_col)[y_col].count()
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Count of {y_col}") # This is a count

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

graph_generate_button.on_click(graph_generate_plot)

# Layout for the second tab
data_graphing_tab = widgets.VBox([
    widgets.HBox([graph_table_dropdown, graph_plot_type]),
    widgets.HBox([graph_x_axis, graph_y_axis]),
    widgets.HBox([graph_generate_button, graph_best_fit]),
    graph_output
])

# Trigger the update once to load the first table's columns
if all_table_names:
    graph_update_columns(None)


# ===============================================
# --- 5. ASSEMBLE AND DISPLAY THE TABS ---
# ===============================================

tab_container = widgets.Tab()
tab_container.children = [data_management_tab, data_graphing_tab]
tab_container.set_title(0, 'Data Management')
tab_container.set_title(1, 'Data Graphing')

display(tab_container)

Successfully imported SQL_Handler.py


NameError: name 'graph_udate_columns' is not defined

In [10]:
# --- 1. Import All Necessary Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import GridspecLayout
from IPython.display import display, clear_output
import io
import numpy as np

# --- Import your custom SQL handler ---
# This assumes SQL_Handler.py is in the same folder as your notebook
try:
    import SQL_Handler
    print("Successfully imported SQL_Handler.py")
except ImportError:
    print("ERROR: SQL_Handler.py not found. Make sure it's in the same directory.")
except Exception as e:
    print(f"Error importing SQL_Handler: {e}")

# --- 2. Shared Database Helper Functions ---
DB_NAME = 'Dataset.db'

def get_table_names():
    """Fetches all table names from the database."""
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0].strip() for table in cursor.fetchall() if table[0] != 'sqlite_sequence']
        conn.close()
        return tables
    except Exception as e:
        print(f"Database error: {e}")
        return []

def get_column_names(table_name):
    """Fetches all column names for a specific table."""
    if not table_name:
        return []
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute(f'PRAGMA table_info("{table_name}")')
        columns = [col[1].strip() for col in cursor.fetchall()]
        conn.close()
        return columns
    except Exception as e:
        print(f"Database error: {e}")
        return []

# Get table names once for all dropdowns
all_table_names = get_table_names()

# ===============================================
# --- 3. CREATE "DATA MANAGEMENT" TAB WIDGETS ---
# ===============================================

manage_table_dropdown = widgets.Dropdown(
    description='Table:',
    options=all_table_names
)
manage_join_dropdown = widgets.Dropdown(
    description='Join Table:',
    options=["(None)"] + all_table_names,
    value="(None)"
)

# --- Replaced SelectMultiple with a simple Dropdown ---
manage_cols_dropdown = widgets.Dropdown(
    description='Column:',
    options=[],
    value=None,
    disabled=True # Start disabled
)

# --- Filter widgets ---
manage_filter_col_dropdown = widgets.Dropdown(
    options=[],
    description='Filter Column:',
    disabled=True
)
manage_filter_op_dropdown = widgets.Dropdown(
    options=['=', '!=', '>', '<', '>=', '<=', 'LIKE'],
    description='Operator:',
    value='='
)
manage_filter_val_text = widgets.Text(
    description='Value:',
    placeholder='e.g., 20 or "Male"'
)

manage_button = widgets.Button(description='Run Query')
manage_output = widgets.Output()

# --- This observer updates all 3 dependent dropdowns ---
def on_manage_table_change(change):
    """Called when the management table dropdown changes."""
    table_name = change['new']
    if table_name:
        columns = get_column_names(table_name)
        
        # Update the new Column Dropdown
        manage_cols_dropdown.options = ["*"] + columns # Add '*' as first option
        manage_cols_dropdown.value = "*" # Select '*' by default
        manage_cols_dropdown.disabled = False
        
        # Update filter column dropdown
        manage_filter_col_dropdown.options = ["(None)"] + columns
        manage_filter_col_dropdown.value = "(None)"
        manage_filter_col_dropdown.disabled = False
    else:
        manage_cols_dropdown.options = []
        manage_cols_dropdown.value = None
        manage_cols_dropdown.disabled = True
        
        manage_filter_col_dropdown.options = []
        manage_filter_col_dropdown.value = None
        manage_filter_col_dropdown.disabled = True

# --- on_manage_query_click ---
def on_manage_query_click(b):
    """Event handler for the Data Management 'Run Query' button."""
    with manage_output:
        clear_output(wait=True)
        
        # Get values from widgets
        table = manage_table_dropdown.value
        
        # Get column from the new dropdown
        cols = manage_cols_dropdown.value
        if not cols:
            cols = '*' # Failsafe
            
        join = manage_join_dropdown.value
        if join == "(None)":
            join = None

        # Build the WHERE condition
        cond = None
        filter_col = manage_filter_col_dropdown.value
        
        if filter_col != "(None)":
            filter_op = manage_filter_op_dropdown.value
            filter_val = manage_filter_val_text.value
            
            if not filter_val.replace('.','',1).isdigit():
                filter_val = f"'{filter_val}'"
                
            cond = f'"{filter_col}" {filter_op} {filter_val}'
            
        print(f"Querying: SELECT {cols} FROM {table}...")
        if join:
            print(f"Joining with: {join}")
        if cond:
            print(f"Condition: WHERE {cond}")
            
        try:
            results = SQL_Handler.data_selection(table, cols, cond, join)
            
            if results:
                # Get column names for the DataFrame header
                if cols == '*':
                    df_cols = get_column_names(table)
                else:
                    df_cols = [cols] # It's just the one column
                    
                if join:
                    df = pd.DataFrame(results)
                    print("Note: Column headers not available for JOIN queries.")
                else:
                    df = pd.DataFrame(results, columns=df_cols)
                
                print(f"\nSuccess! Found {len(results)} rows.")
                display(df)
            else:
                print("\nQuery executed, but returned no results.")
                
        except Exception as e:
            print(f"\nAn error occurred: {e}")
            print("---")
            print("Make sure your SQL syntax is correct and columns exist.")

# Wire up the observer
manage_table_dropdown.observe(on_manage_table_change, names='value')
manage_button.on_click(on_manage_query_click)

# Create a VBox for the 3-part filter
manage_filter_box = widgets.VBox([
    manage_filter_col_dropdown,
    manage_filter_op_dropdown,
    manage_filter_val_text
], layout={'border': '1px solid #CCC', 'padding': '10px'})

# The layout for the first tab
data_management_tab = widgets.VBox([
    widgets.HBox([manage_table_dropdown, manage_join_dropdown]),
    widgets.HBox([manage_cols_dropdown, manage_filter_box]), 
    manage_button,
    manage_output
])

# Trigger the column update for the first table
if all_table_names:
    on_manage_table_change({'new': all_table_names[0]})


# ===============================================
# --- 4. CREATE "DATA GRAPHING" TAB WIDGETS ---
# ===============================================

# Widgets (renamed to avoid conflicts)
graph_table_dropdown = widgets.Dropdown(options=all_table_names, description="Table:")
graph_plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
graph_x_axis = widgets.Dropdown(description="X-Axis:")
graph_y_axis = widgets.Dropdown(description="Y-Axis:")
graph_best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
graph_generate_button = widgets.Button(description="Generate Plot", button_style="success")
graph_output = widgets.Output()

# Update dropdowns when table or plot type changes
def graph_update_columns(change):
    table = graph_table_dropdown.value
    ptype = graph_plot_type.value # Get the plot type
    
    conn = sqlite3.connect(DB_NAME)
    # Read only a few rows to quickly get column names and types
    try:
        df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    except Exception as e:
        with graph_output:
            print(f"Error reading table {table}: {e}")
            graph_x_axis.options = []
            graph_y_axis.options = []
        return
    finally:
        conn.close()

    if ptype == "Scatter":
        # SCATTER: Only show numeric columns
        numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
        graph_x_axis.options = numeric_cols
        graph_y_axis.options = numeric_cols
        graph_best_fit.layout.display = 'flex' # Show checkbox
    else:
        # BAR: Show all columns
        all_cols = df.columns.tolist()
        graph_x_axis.options = all_cols
        graph_y_axis.options = all_cols
        graph_best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if graph_x_axis.options:
        graph_x_axis.value = graph_x_axis.options[0]
    if graph_y_axis.options and len(graph_y_axis.options) > 1:
        # Try to set a different default for Y if possible
        graph_y_axis.value = graph_y_axis.options[1]
    elif graph_y_axis.options:
        graph_y_axis.value = graph_y_axis.options[0]

# --- OBSERVERS ---
# --- THIS IS THE FIXED LINE ---
graph_table_dropdown.observe(graph_update_columns, names="value")
graph_plot_type.observe(graph_update_columns, names="value")
# ------------------------------

# Plot when button is clicked
def graph_generate_plot(b):
    with graph_output:
        clear_output(wait=True)
        
        table = graph_table_dropdown.value
        x_col = graph_x_axis.value
        y_col = graph_y_axis.value
        kind = graph_plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        try:
            conn = sqlite3.connect(DB_NAME)
            df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
            conn.close()
        except Exception as e:
            print(f"Error querying database: {e}")
            return

        # --- Age Filter ---
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            original_count = len(df) # Get count before filtering
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            # We round 'age' *before* plotting
            if 'age' in df.columns:
                df['age'] = df['age'].round()
            print(f"Filter applied: Kept {len(df)} of {original_count} students aged 18–22.")

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            # Convert to numeric and drop NAs
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close() # Close the empty figure
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if graph_best_fit.value:
                if len(df) > 1:
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Try to convert columns to numeric where possible
            x_data = pd.to_numeric(df[x_col], errors='coerce')
            if x_data.isnull().all(): # Failed to convert (it's text)
                x_data = df[x_col]
                
            y_data = pd.to_numeric(df[y_col], errors='coerce')
            if y_data.isnull().all(): # Failed to convert (it't text)
                y_data = df[y_col]
                
            df[x_col] = x_data
            df[y_col] = y_data
            
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close() # Close the empty figure
                return
            
            # --- THIS IS THE SMARTER LOGIC ---
            
            def is_col_categorical(col_name, unique_thresh=25):
                """Helper function to check if a column should be treated as categorical."""
                if pd.api.types.is_object_dtype(df[col_name]):
                    return True
                if pd.api.types.is_numeric_dtype(df[col_name]):
                    if df[col_name].nunique() < unique_thresh:
                        if 'id' not in str(col_name).lower():
                            return True
                return False

            is_x_categorical = is_col_categorical(x_col)
            is_y_categorical = is_col_categorical(y_col)
            
            # Case 1: Y is categorical (e.g., X=age, Y=gender)
            if is_y_categorical:
                grouped = pd.crosstab(df[x_col], df[y_col]) 
                grouped = grouped.sort_index()
                
                indices = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        indices + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count") # This is a count

                labels = grouped.index.astype(str)
                positions = indices + bar_width*(len(grouped.columns)/2 - 0.5)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 2: Y is numeric (not categorical) and X is categorical
            # (e.g., X=age, Y=student_id)
            # This is a "Count" chart
            elif is_x_categorical and not is_y_categorical:
                agg_data = df.groupby(x_col)[y_col].count() 
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel("Number of Students") 

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 3: Both X and Y are numeric (and not categorical)
            else:
                agg_data = df.groupby(x_col)[y_col].count()
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Count of {y_col}") # This is a count

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

graph_generate_button.on_click(graph_generate_plot)

# Layout for the second tab
data_graphing_tab = widgets.VBox([
    widgets.HBox([graph_table_dropdown, graph_plot_type]),
    widgets.HBox([graph_x_axis, graph_y_axis]),
    widgets.HBox([graph_generate_button, graph_best_fit]),
    graph_output
])

# Trigger the update once to load the first table's columns
if all_table_names:
    graph_update_columns(None)


# ===============================================
# --- 5. ASSEMBLE AND DISPLAY THE TABS ---
# ===============================================

tab_container = widgets.Tab()
tab_container.children = [data_management_tab, data_graphing_tab]
tab_container.set_title(0, 'Data Management')
tab_container.set_title(1, 'Data Graphing')

display(tab_container)

Successfully imported SQL_Handler.py


Tab(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'P…

In [1]:
# --- 1. Import All Necessary Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import GridspecLayout
from IPython.display import display, clear_output
import io
import numpy as np

# --- Import your custom SQL handler ---
# This assumes SQL_Handler.py is in the same folder as your notebook
try:
    import SQL_Handler
    print("Successfully imported SQL_Handler.py")
except ImportError:
    print("ERROR: SQL_Handler.py not found. Make sure it's in the same directory.")
except Exception as e:
    print(f"Error importing SQL_Handler: {e}")

# --- 2. Shared Database Helper Functions ---
DB_NAME = 'Dataset.db'

def get_table_names():
    """Fetches all table names from the database."""
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0].strip() for table in cursor.fetchall() if table[0] != 'sqlite_sequence']
        conn.close()
        return tables
    except Exception as e:
        print(f"Database error: {e}")
        return []

def get_column_names(table_name):
    """Fetches all column names for a specific table."""
    if not table_name:
        return []
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute(f'PRAGMA table_info("{table_name}")')
        columns = [col[1].strip() for col in cursor.fetchall()]
        conn.close()
        return columns
    except Exception as e:
        print(f"Database error: {e}")
        return []

# Get table names once for all dropdowns
all_table_names = get_table_names()

# ===============================================
# --- 3. CREATE "DATA MANAGEMENT" TAB WIDGETS ---
# ===============================================

manage_table_dropdown = widgets.Dropdown(
    description='Table:',
    options=all_table_names
)
manage_join_dropdown = widgets.Dropdown(
    description='Join Table:',
    options=["(None)"] + all_table_names,
    value="(None)"
)

# --- Replaced SelectMultiple with a simple Dropdown ---
manage_cols_dropdown = widgets.Dropdown(
    description='Column:',
    options=[],
    value=None,
    disabled=True # Start disabled
)

# --- Filter widgets ---
manage_filter_col_dropdown = widgets.Dropdown(
    options=[],
    description='Filter Column:',
    disabled=True
)
manage_filter_op_dropdown = widgets.Dropdown(
    options=['=', '!=', '>', '<', '>=', '<=', 'LIKE'],
    description='Operator:',
    value='='
)
manage_filter_val_text = widgets.Text(
    description='Value:',
    placeholder='e.g., 20 or "Male"'
)

manage_button = widgets.Button(description='Run Query')
manage_output = widgets.Output()

# --- This observer updates all 3 dependent dropdowns ---
def on_manage_table_change(change):
    """Called when the management table dropdown changes."""
    table_name = change['new']
    if table_name:
        columns = get_column_names(table_name)
        
        # Update the new Column Dropdown
        manage_cols_dropdown.options = ["*"] + columns # Add '*' as first option
        manage_cols_dropdown.value = "*" # Select '*' by default
        manage_cols_dropdown.disabled = False
        
        # Update filter column dropdown
        manage_filter_col_dropdown.options = ["(None)"] + columns
        manage_filter_col_dropdown.value = "(None)"
        manage_filter_col_dropdown.disabled = False
    else:
        manage_cols_dropdown.options = []
        manage_cols_dropdown.value = None
        manage_cols_dropdown.disabled = True
        
        manage_filter_col_dropdown.options = []
        manage_filter_col_dropdown.value = None
        manage_filter_col_dropdown.disabled = True

# --- on_manage_query_click ---
def on_manage_query_click(b):
    """Event handler for the Data Management 'Run Query' button."""
    with manage_output:
        clear_output(wait=True)
        
        # Get values from widgets
        table = manage_table_dropdown.value
        
        # Get column from the new dropdown
        cols = manage_cols_dropdown.value
        if not cols:
            cols = '*' # Failsafe
            
        join = manage_join_dropdown.value
        if join == "(None)":
            join = None

        # Build the WHERE condition
        cond = None
        filter_col = manage_filter_col_dropdown.value
        
        if filter_col != "(None)":
            filter_op = manage_filter_op_dropdown.value
            filter_val = manage_filter_val_text.value
            
            if not filter_val.replace('.','',1).isdigit():
                filter_val = f"'{filter_val}'"
                
            cond = f'"{filter_col}" {filter_op} {filter_val}'
            
        print(f"Querying: SELECT {cols} FROM {table}...")
        if join:
            print(f"Joining with: {join}")
        if cond:
            print(f"Condition: WHERE {cond}")
            
        try:
            results = SQL_Handler.data_selection(table, cols, cond, join)
            
            if results:
                # Get column names for the DataFrame header
                if cols == '*':
                    df_cols = get_column_names(table)
                else:
                    df_cols = [cols] # It's just the one column
                    
                if join:
                    df = pd.DataFrame(results)
                    print("Note: Column headers not available for JOIN queries.")
                else:
                    df = pd.DataFrame(results, columns=df_cols)
                
                print(f"\nSuccess! Found {len(results)} rows.")
                display(df)
            else:
                print("\nQuery executed, but returned no results.")
                
        except Exception as e:
            print(f"\nAn error occurred: {e}")
            print("---")
            print("Make sure your SQL syntax is correct and columns exist.")

# Wire up the observer
manage_table_dropdown.observe(on_manage_table_change, names='value')
manage_button.on_click(on_manage_query_click)

# Create a VBox for the 3-part filter
manage_filter_box = widgets.VBox([
    manage_filter_col_dropdown,
    manage_filter_op_dropdown,
    manage_filter_val_text
], layout={'border': '1px solid #CCC', 'padding': '10px'})

# The layout for the first tab
data_management_tab = widgets.VBox([
    widgets.HBox([manage_table_dropdown, manage_join_dropdown]),
    widgets.HBox([manage_cols_dropdown, manage_filter_box]), 
    manage_button,
    manage_output
])

# Trigger the column update for the first table
if all_table_names:
    on_manage_table_change({'new': all_table_names[0]})


# ===============================================
# --- 4. CREATE "DATA GRAPHING" TAB WIDGETS ---
# ===============================================

# Widgets (renamed to avoid conflicts)
graph_table_dropdown = widgets.Dropdown(options=all_table_names, description="Table:")
graph_plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
graph_x_axis = widgets.Dropdown(description="X-Axis:")
graph_y_axis = widgets.Dropdown(description="Y-Axis:")
graph_best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
graph_generate_button = widgets.Button(description="Generate Plot", button_style="success")
graph_output = widgets.Output()

# Update dropdowns when table or plot type changes
def graph_update_columns(change):
    table = graph_table_dropdown.value
    ptype = graph_plot_type.value # Get the plot type
    
    conn = sqlite3.connect(DB_NAME)
    # Read only a few rows to quickly get column names and types
    try:
        df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    except Exception as e:
        with graph_output:
            print(f"Error reading table {table}: {e}")
            graph_x_axis.options = []
            graph_y_axis.options = []
        return
    finally:
        conn.close()

    if ptype == "Scatter":
        # SCATTER: Only show numeric columns
        numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
        graph_x_axis.options = numeric_cols
        graph_y_axis.options = numeric_cols
        graph_best_fit.layout.display = 'flex' # Show checkbox
    else:
        # BAR: Show all columns
        all_cols = df.columns.tolist()
        graph_x_axis.options = all_cols
        graph_y_axis.options = all_cols
        graph_best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if graph_x_axis.options:
        graph_x_axis.value = graph_x_axis.options[0]
    if graph_y_axis.options and len(graph_y_axis.options) > 1:
        # Try to set a different default for Y if possible
        graph_y_axis.value = graph_y_axis.options[1]
    elif graph_y_axis.options:
        graph_y_axis.value = graph_y_axis.options[0]

# --- OBSERVERS ---
# --- THIS IS THE FIXED LINE ---
graph_table_dropdown.observe(graph_update_columns, names="value")
graph_plot_type.observe(graph_update_columns, names="value")
# ------------------------------

# Plot when button is clicked
def graph_generate_plot(b):
    with graph_output:
        clear_output(wait=True)
        
        table = graph_table_dropdown.value
        x_col = graph_x_axis.value
        y_col = graph_y_axis.value
        kind = graph_plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        try:
            conn = sqlite3.connect(DB_NAME)
            df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
            conn.close()
        except Exception as e:
            print(f"Error querying database: {e}")
            return

        # --- Age Filter ---
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            original_count = len(df) # Get count before filtering
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            # We round 'age' *before* plotting
            if 'age' in df.columns:
                df['age'] = df['age'].round()
            print(f"Filter applied: Kept {len(df)} of {original_count} students aged 18–22.")

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            # Convert to numeric and drop NAs
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close() # Close the empty figure
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if graph_best_fit.value:
                if len(df) > 1:
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Try to convert columns to numeric where possible
            x_data = pd.to_numeric(df[x_col], errors='coerce')
            if x_data.isnull().all(): # Failed to convert (it's text)
                x_data = df[x_col]
                
            y_data = pd.to_numeric(df[y_col], errors='coerce')
            if y_data.isnull().all(): # Failed to convert (it't text)
                y_data = df[y_col]
                
            df[x_col] = x_data
            df[y_col] = y_data
            
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close() # Close the empty figure
                return
            
            # --- THIS IS THE SMARTER LOGIC ---
            
            def is_col_categorical(col_name, unique_thresh=25):
                """Helper function to check if a column should be treated as categorical."""
                if pd.api.types.is_object_dtype(df[col_name]):
                    return True
                if pd.api.types.is_numeric_dtype(df[col_name]):
                    if df[col_name].nunique() < unique_thresh:
                        if 'id' not in str(col_name).lower():
                            return True
                return False

            is_x_categorical = is_col_categorical(x_col)
            is_y_categorical = is_col_categorical(y_col)
            
            # Case 1: Y is categorical (e.g., X=age, Y=gender)
            if is_y_categorical:
                grouped = pd.crosstab(df[x_col], df[y_col]) 
                grouped = grouped.sort_index()
                
                indices = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        indices + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count") # This is a count

                labels = grouped.index.astype(str)
                positions = indices + bar_width*(len(grouped.columns)/2 - 0.5)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 2: Y is numeric (not categorical) and X is categorical
            # (e.g., X=age, Y=student_id)
            # This is a "Count" chart
            elif is_x_categorical and not is_y_categorical:
                agg_data = df.groupby(x_col)[y_col].count() 
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel("Number of Students") 

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 3: Both X and Y are numeric (and not categorical)
            else:
                agg_data = df.groupby(x_col)[y_col].count()
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Count of {y_col}") # This is a count

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

graph_generate_button.on_click(graph_generate_plot)

# Layout for the second tab
data_graphing_tab = widgets.VBox([
    widgets.HBox([graph_table_dropdown, graph_plot_type]),
    widgets.HBox([graph_x_axis, graph_y_axis]),
    widgets.HBox([graph_generate_button, graph_best_fit]),
    graph_output
])

# Trigger the update once to load the first table's columns
if all_table_names:
    graph_update_columns(None)


# ===============================================
# --- 5. ASSEMBLE AND DISPLAY THE TABS ---
# ===============================================

tab_container = widgets.Tab()
tab_container.children = [data_management_tab, data_graphing_tab]
tab_container.set_title(0, 'Data Management')
tab_container.set_title(1, 'Data Graphing')

display(tab_container)

Successfully imported SQL_Handler.py


Tab(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'P…

In [2]:
# --- 1. Import All Necessary Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import GridspecLayout
from IPython.display import display, clear_output
import io
import numpy as np

# --- Import your custom SQL handler ---
# This assumes SQL_Handler.py is in the same folder as your notebook
try:
    import SQL_Handler
    print("Successfully imported SQL_Handler.py")
except ImportError:
    print("ERROR: SQL_Handler.py not found. Make sure it's in the same directory.")
except Exception as e:
    print(f"Error importing SQL_Handler: {e}")

# --- 2. Shared Database Helper Functions ---
DB_NAME = 'Dataset.db'

def get_table_names():
    """Fetches all table names from the database."""
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0].strip() for table in cursor.fetchall() if table[0] != 'sqlite_sequence']
        conn.close()
        return tables
    except Exception as e:
        print(f"Database error: {e}")
        return []

def get_column_names(table_name):
    """Fetches all column names for a specific table."""
    if not table_name:
        return []
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute(f'PRAGMA table_info("{table_name}")')
        columns = [col[1].strip() for col in cursor.fetchall()]
        conn.close()
        return columns
    except Exception as e:
        print(f"Database error: {e}")
        return []

# Get table names once for all dropdowns
all_table_names = get_table_names()

# ===============================================
# --- 3. CREATE "DATA MANAGEMENT" TAB WIDGETS ---
# ===============================================

manage_table_dropdown = widgets.Dropdown(
    description='Table:',
    options=all_table_names
)
manage_join_dropdown = widgets.Dropdown(
    description='Join Table:',
    options=["(None)"] + all_table_names,
    value="(None)"
)

# --- Replaced SelectMultiple with a simple Dropdown ---
manage_cols_dropdown = widgets.Dropdown(
    description='Column:',
    options=[],
    value=None,
    disabled=True # Start disabled
)

# --- Filter widgets ---
manage_filter_col_dropdown = widgets.Dropdown(
    options=[],
    description='Filter Column:',
    disabled=True
)
manage_filter_op_dropdown = widgets.Dropdown(
    options=['=', '!=', '>', '<', '>=', '<=', 'LIKE'],
    description='Operator:',
    value='='
)
# --- THIS IS THE UPDATED LINE ---
manage_filter_val_text = widgets.Text(
    description='Value:',
    placeholder="e.g., 20 or 'M'"
)
# ---------------------------------

manage_button = widgets.Button(description='Run Query')
manage_output = widgets.Output()

# --- This observer updates all 3 dependent dropdowns ---
def on_manage_table_change(change):
    """Called when the management table dropdown changes."""
    table_name = change['new']
    if table_name:
        columns = get_column_names(table_name)
        
        # Update the new Column Dropdown
        manage_cols_dropdown.options = ["*"] + columns # Add '*' as first option
        manage_cols_dropdown.value = "*" # Select '*' by default
        manage_cols_dropdown.disabled = False
        
        # Update filter column dropdown
        manage_filter_col_dropdown.options = ["(None)"] + columns
        manage_filter_col_dropdown.value = "(None)"
        manage_filter_col_dropdown.disabled = False
    else:
        manage_cols_dropdown.options = []
        manage_cols_dropdown.value = None
        manage_cols_dropdown.disabled = True
        
        manage_filter_col_dropdown.options = []
        manage_filter_col_dropdown.value = None
        manage_filter_col_dropdown.disabled = True

# --- on_manage_query_click ---
def on_manage_query_click(b):
    """Event handler for the Data Management 'Run Query' button."""
    with manage_output:
        clear_output(wait=True)
        
        # Get values from widgets
        table = manage_table_dropdown.value
        
        # Get column from the new dropdown
        cols = manage_cols_dropdown.value
        if not cols:
            cols = '*' # Failsafe
            
        join = manage_join_dropdown.value
        if join == "(None)":
            join = None

        # Build the WHERE condition
        cond = None
        filter_col = manage_filter_col_dropdown.value
        
        if filter_col != "(None)":
            filter_op = manage_filter_op_dropdown.value
            filter_val = manage_filter_val_text.value
            
            # Smartly add quotes to the value if it's not a number
            if not filter_val.replace('.','',1).isdigit():
                filter_val = f"'{filter_val}'"
                
            cond = f'"{filter_col}" {filter_op} {filter_val}'
            
        print(f"Querying: SELECT {cols} FROM {table}...")
        if join:
            print(f"Joining with: {join}")
        if cond:
            print(f"Condition: WHERE {cond}")
            
        try:
            results = SQL_Handler.data_selection(table, cols, cond, join)
            
            if results:
                # Get column names for the DataFrame header
                if cols == '*':
                    df_cols = get_column_names(table)
                else:
                    df_cols = [cols] # It's just the one column
                    
                if join:
                    df = pd.DataFrame(results)
                    print("Note: Column headers not available for JOIN queries.")
                else:
                    df = pd.DataFrame(results, columns=df_cols)
                
                print(f"\nSuccess! Found {len(results)} rows.")
                display(df)
            else:
                print("\nQuery executed, but returned no results.")
                
        except Exception as e:
            print(f"\nAn error occurred: {e}")
            print("---")
            print("Make sure your SQL syntax is correct and columns exist.")

# Wire up the observer
manage_table_dropdown.observe(on_manage_table_change, names='value')
manage_button.on_click(on_manage_query_click)

# Create a VBox for the 3-part filter
manage_filter_box = widgets.VBox([
    manage_filter_col_dropdown,
    manage_filter_op_dropdown,
    manage_filter_val_text
], layout={'border': '1px solid #CCC', 'padding': '10px'})

# The layout for the first tab
data_management_tab = widgets.VBox([
    widgets.HBox([manage_table_dropdown, manage_join_dropdown]),
    widgets.HBox([manage_cols_dropdown, manage_filter_box]), 
    manage_button,
    manage_output
])

# Trigger the column update for the first table
if all_table_names:
    on_manage_table_change({'new': all_table_names[0]})


# ===============================================
# --- 4. CREATE "DATA GRAPHING" TAB WIDGETS ---
# ===============================================

# Widgets (renamed to avoid conflicts)
graph_table_dropdown = widgets.Dropdown(options=all_table_names, description="Table:")
graph_plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
graph_x_axis = widgets.Dropdown(description="X-Axis:")
graph_y_axis = widgets.Dropdown(description="Y-Axis:")
graph_best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
graph_generate_button = widgets.Button(description="Generate Plot", button_style="success")
graph_output = widgets.Output()

# Update dropdowns when table or plot type changes
def graph_update_columns(change):
    table = graph_table_dropdown.value
    ptype = graph_plot_type.value # Get the plot type
    
    conn = sqlite3.connect(DB_NAME)
    # Read only a few rows to quickly get column names and types
    try:
        df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    except Exception as e:
        with graph_output:
            print(f"Error reading table {table}: {e}")
            graph_x_axis.options = []
            graph_y_axis.options = []
        return
    finally:
        conn.close()

    if ptype == "Scatter":
        # SCATTER: Only show numeric columns
        numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
        graph_x_axis.options = numeric_cols
        graph_y_axis.options = numeric_cols
        graph_best_fit.layout.display = 'flex' # Show checkbox
    else:
        # BAR: Show all columns
        all_cols = df.columns.tolist()
        graph_x_axis.options = all_cols
        graph_y_axis.options = all_cols
        graph_best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if graph_x_axis.options:
        graph_x_axis.value = graph_x_axis.options[0]
    if graph_y_axis.options and len(graph_y_axis.options) > 1:
        # Try to set a different default for Y if possible
        graph_y_axis.value = graph_y_axis.options[1]
    elif graph_y_axis.options:
        graph_y_axis.value = graph_y_axis.options[0]

# --- OBSERVERS ---
graph_table_dropdown.observe(graph_update_columns, names="value")
graph_plot_type.observe(graph_update_columns, names="value")

# Plot when button is clicked
def graph_generate_plot(b):
    with graph_output:
        clear_output(wait=True)
        
        table = graph_table_dropdown.value
        x_col = graph_x_axis.value
        y_col = graph_y_axis.value
        kind = graph_plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        try:
            conn = sqlite3.connect(DB_NAME)
            df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
            conn.close()
        except Exception as e:
            print(f"Error querying database: {e}")
            return

        # --- Age Filter ---
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            original_count = len(df) # Get count before filtering
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            # We round 'age' *before* plotting
            if 'age' in df.columns:
                df['age'] = df['age'].round()
            print(f"Filter applied: Kept {len(df)} of {original_count} students aged 18–22.")

        plt.figure(figsize=(7,4))
        
        # --- Scatter plot ---
        if kind == "Scatter":
            # Convert to numeric and drop NAs
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close() # Close the empty figure
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if graph_best_fit.value:
                if len(df) > 1:
                    m, b = np.polyfit(df[x_col], df[y_col], 1)
                    plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar plot ---
        else:
            # Try to convert columns to numeric where possible
            x_data = pd.to_numeric(df[x_col], errors='coerce')
            if x_data.isnull().all(): # Failed to convert (it's text)
                x_data = df[x_col]
                
            y_data = pd.to_numeric(df[y_col], errors='coerce')
            if y_data.isnull().all(): # Failed to convert (it't text)
                y_data = df[y_col]
                
            df[x_col] = x_data
            df[y_col] = y_data
            
            df = df.dropna(subset=[x_col, y_col]) # Drop NAs here
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close() # Close the empty figure
                return
            
            # --- THIS IS THE SMARTER LOGIC ---
            
            def is_col_categorical(col_name, unique_thresh=25):
                """Helper function to check if a column should be treated as categorical."""
                if pd.api.types.is_object_dtype(df[col_name]):
                    return True
                if pd.api.types.is_numeric_dtype(df[col_name]):
                    if df[col_name].nunique() < unique_thresh:
                        if 'id' not in str(col_name).lower():
                            return True
                return False

            is_x_categorical = is_col_categorical(x_col)
            is_y_categorical = is_col_categorical(y_col)
            
            # Case 1: Y is categorical (e.g., X=age, Y=gender)
            if is_y_categorical:
                grouped = pd.crosstab(df[x_col], df[y_col]) 
                grouped = grouped.sort_index()
                
                indices = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        indices + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count") # This is a count

                labels = grouped.index.astype(str)
                positions = indices + bar_width*(len(grouped.columns)/2 - 0.5)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 2: Y is numeric (not categorical) and X is categorical
            # (e.g., X=age, Y=student_id)
            # This is a "Count" chart
            elif is_x_categorical and not is_y_categorical:
                agg_data = df.groupby(x_col)[y_col].count() 
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel("Number of Students") 

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 3: Both X and Y are numeric (and not categorical)
            else:
                agg_data = df.groupby(x_col)[y_col].count()
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Count of {y_col}") # This is a count

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

graph_generate_button.on_click(graph_generate_plot)

# Layout for the second tab
data_graphing_tab = widgets.VBox([
    widgets.HBox([graph_table_dropdown, graph_plot_type]),
    widgets.HBox([graph_x_axis, graph_y_axis]),
    widgets.HBox([graph_generate_button, graph_best_fit]),
    graph_output
])

# Trigger the update once to load the first table's columns
if all_table_names:
    graph_update_columns(None)


# ===============================================
# --- 5. ASSEMBLE AND DISPLAY THE TABS ---
# ===============================================

tab_container = widgets.Tab()
tab_container.children = [data_management_tab, data_graphing_tab]
tab_container.set_title(0, 'Data Management')
tab_container.set_title(1, 'Data Graphing')

display(tab_container)

Successfully imported SQL_Handler.py


Tab(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'P…

In [3]:
# ===============================================================
# 1. IMPORT LIBRARIES (BEGINNER FRIENDLY)
# ===============================================================

import sqlite3          # lets us connect to a database
import pandas as pd     # helps show tables nicely
import matplotlib.pyplot as plt   # lets us make charts
import ipywidgets as widgets       # lets us make dropdowns/buttons
from IPython.display import display, clear_output

import numpy as np      # used for best-fit line math

# Try to import your SQL handler file
try:
    import SQL_Handler
    print("Loaded SQL_Handler.py")
except:
    print("Could NOT load SQL_Handler.py")

DB_NAME = "Dataset.db"



# ===============================================================
# 2. HELPER FUNCTIONS (BEGINNER FRIENDLY)
# ===============================================================

# Get list of tables
def get_tables():
    connection = sqlite3.connect(DB_NAME)
    cursor = connection.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    table_names = [row[0] for row in cursor.fetchall() if row[0] != "sqlite_sequence"]
    connection.close()
    return table_names

# Get list of columns in a table
def get_columns(table):
    connection = sqlite3.connect(DB_NAME)
    cursor = connection.cursor()
    cursor.execute(f"PRAGMA table_info('{table}')")
    columns = [row[1] for row in cursor.fetchall()]
    connection.close()
    return columns



# ===============================================================
# 3. CREATE WIDGETS FOR DATA MANAGEMENT TAB
# ===============================================================

table_dropdown = widgets.Dropdown(description="Table:")
join_dropdown = widgets.Dropdown(description="Join:")
column_dropdown = widgets.Dropdown(description="Column:")
filter_column_dropdown = widgets.Dropdown(description="Filter Column:")
filter_operator_dropdown = widgets.Dropdown(
    description="Op:",
    options=["=", "!=", ">", "<", ">=", "<=", "LIKE"]
)
filter_value_text = widgets.Text(description="Value:")
run_button = widgets.Button(description="Run Query")
output_box = widgets.Output()

# Fill table names
tables = get_tables()
table_dropdown.options = tables
join_dropdown.options = ["(None)"] + tables



# ===============================================================
# When the user picks a table, update column options
# ===============================================================
def update_columns(change):
    table = table_dropdown.value
    if table:
        cols = get_columns(table)
        column_dropdown.options = ["*"] + cols
        filter_column_dropdown.options = ["(None)"] + cols

table_dropdown.observe(update_columns, names="value")



# ===============================================================
# When user clicks RUN QUERY button
# ===============================================================
def run_query(button):
    with output_box:
        clear_output()

        table = table_dropdown.value
        column = column_dropdown.value
        join_table = join_dropdown.value
        filter_col = filter_column_dropdown.value
        op = filter_operator_dropdown.value
        val = filter_value_text.value

        # Build condition string
        condition = None
        if filter_col != "(None)":
            # Add quotes if value is text
            if val.replace(".", "", 1).isdigit():
                condition = f"{filter_col} {op} {val}"
            else:
                condition = f"{filter_col} {op} '{val}'"

        # Build join
        if join_table == "(None)":
            join_table = None

        print("Running Query...")

        try:
            results = SQL_Handler.data_selection(
                table,
                column,
                condition,
                join_table
            )

            if results:
                print("Query Successful. Rows Found:", len(results))

                if column == "*":
                    col_names = get_columns(table)
                    df = pd.DataFrame(results, columns=col_names)
                else:
                    df = pd.DataFrame(results)

                display(df)
            else:
                print("No results found.")

        except Exception as e:
            print("Error:", e)

run_button.on_click(run_query)



# ===============================================================
# 4. DATA GRAPHING TAB (BEGINNER FRIENDLY)
# ===============================================================

graph_table_dropdown = widgets.Dropdown(
    description="Table:",
    options=tables
)
plot_type_dropdown = widgets.Dropdown(
    description="Plot:",
    options=["Bar", "Scatter"]
)
x_dropdown = widgets.Dropdown(description="X:")
y_dropdown = widgets.Dropdown(description="Y:")
best_fit_checkbox = widgets.Checkbox(description="Best Fit Line", value=True)
plot_button = widgets.Button(description="Make Plot", button_style="success")
plot_output = widgets.Output()



# ===============================================================
# When the user picks a table OR plot type, update columns
# ===============================================================
def update_graph_columns(change):
    table = graph_table_dropdown.value

    try:
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        cols = df.columns.tolist()
        x_dropdown.options = cols
        y_dropdown.options = cols

        if plot_type_dropdown.value == "Scatter":
            numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
            x_dropdown.options = numeric_cols
            y_dropdown.options = numeric_cols
            best_fit_checkbox.layout.display = "flex"
        else:
            best_fit_checkbox.layout.display = "none"

    except:
        x_dropdown.options = []
        y_dropdown.options = []

graph_table_dropdown.observe(update_graph_columns, names="value")
plot_type_dropdown.observe(update_graph_columns, names="value")
update_graph_columns(None)



# ===============================================================
# When user clicks "Make Plot"
# ===============================================================
def make_plot(button):
    with plot_output:
        clear_output()

        table = graph_table_dropdown.value
        xcol = x_dropdown.value
        ycol = y_dropdown.value
        ptype = plot_type_dropdown.value

        if not xcol or not ycol:
            print("Pick X and Y columns first.")
            return

        # Load data
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # Plot
        plt.figure(figsize=(7, 4))

        if ptype == "Scatter":
            df[xcol] = pd.to_numeric(df[xcol], errors="coerce")
            df[ycol] = pd.to_numeric(df[ycol], errors="coerce")
            df = df.dropna(subset=[xcol, ycol])

            plt.scatter(df[xcol], df[ycol])

            # Best fit line
            if best_fit_checkbox.value:
                m, b = np.polyfit(df[xcol], df[ycol], 1)
                plt.plot(df[xcol], m*df[xcol] + b)

        else:  # Bar chart
            try:
                counts = df.groupby(xcol)[ycol].count()
                counts.plot(kind="bar")
            except:
                print("Cannot make bar chart with selected columns.")
                return

        plt.title(f"{ycol} vs {xcol}")
        plt.xlabel(xcol)
        plt.ylabel(ycol)
        plt.tight_layout()
        plt.show()

plot_button.on_click(make_plot)



# ===============================================================
# 5. ADD BOTH TABS
# ===============================================================

data_tab = widgets.VBox([
    table_dropdown,
    join_dropdown,
    column_dropdown,
    filter_column_dropdown,
    filter_operator_dropdown,
    filter_value_text,
    run_button,
    output_box
])

graph_tab = widgets.VBox([
    graph_table_dropdown,
    plot_type_dropdown,
    x_dropdown,
    y_dropdown,
    best_fit_checkbox,
    plot_button,
    plot_output
])

tabs = widgets.Tab()
tabs.children = [data_tab, graph_tab]
tabs.set_title(0, "Data Management")
tabs.set_title(1, "Graphing")

display(tabs)


Loaded SQL_Handler.py


Tab(children=(VBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological', …

In [4]:
# ===============================================================
# 1. IMPORT LIBRARIES (BEGINNER FRIENDLY)
# ===============================================================

import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import numpy as np

# Try to import SQL handler
try:
    import SQL_Handler
    print("Loaded SQL_Handler.py")
except:
    print("Could NOT load SQL_Handler.py")

DB_NAME = "Dataset.db"



# ===============================================================
# 2. HELPER FUNCTIONS
# ===============================================================

def get_tables():
    conn = sqlite3.connect(DB_NAME)
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    names = [row[0] for row in cursor.fetchall() if row[0] != "sqlite_sequence"]
    conn.close()
    return names

def get_columns(table):
    conn = sqlite3.connect(DB_NAME)
    cursor = conn.cursor()
    cursor.execute(f"PRAGMA table_info('{table}')")
    cols = [row[1] for row in cursor.fetchall()]
    conn.close()
    return cols



# ===============================================================
# 3. DATA MANAGEMENT TAB WIDGETS
# ===============================================================

table_dropdown = widgets.Dropdown(description="Table:")
join_dropdown = widgets.Dropdown(description="Join:")
column_dropdown = widgets.Dropdown(description="Column:")
filter_column_dropdown = widgets.Dropdown(description="Filter Col:")
filter_operator_dropdown = widgets.Dropdown(
    description="Op:",
    options=["=", "!=", ">", "<", ">=", "<=", "LIKE"]
)
filter_value_text = widgets.Text(description="Value:")
run_button = widgets.Button(description="Run Query")
output_box = widgets.Output()

tables = get_tables()
table_dropdown.options = tables
join_dropdown.options = ["(None)"] + tables



# ===============================================================
# Update columns when table changes
# ===============================================================

def update_columns(change):
    table = table_dropdown.value
    if table:
        cols = get_columns(table)
        column_dropdown.options = ["*"] + cols
        filter_column_dropdown.options = ["(None)"] + cols

table_dropdown.observe(update_columns, names="value")



# ===============================================================
# RUN QUERY BUTTON
# ===============================================================

def run_query(button):
    with output_box:
        clear_output()

        table = table_dropdown.value
        column = column_dropdown.value
        join_table = join_dropdown.value
        filter_col = filter_column_dropdown.value
        op = filter_operator_dropdown.value
        val = filter_value_text.value

        # Build filter (WHERE)
        condition = None
        if filter_col != "(None)":
            if val.replace(".", "", 1).isdigit():
                condition = f"{filter_col} {op} {val}"
            else:
                condition = f"{filter_col} {op} '{val}'"

        if join_table == "(None)":
            join_table = None

        print("Running Query...")

        try:
            results = SQL_Handler.data_selection(
                table,
                column,
                condition,
                join_table
            )

            if results:
                # Turn into DataFrame
                if column == "*":
                    df = pd.DataFrame(results, columns=get_columns(table))
                else:
                    df = pd.DataFrame(results)

                # -----------------------------
                # ⭐ FILTER AGE 18–22 ONLY ⭐
                # -----------------------------
                if "age" in df.columns:
                    df["age"] = pd.to_numeric(df["age"], errors="coerce")
                    df = df.dropna(subset=["age"])
                    df = df[(df["age"] >= 18) & (df["age"] <= 22)]

                print("Rows after age filter:", len(df))
                display(df)

            else:
                print("No results.")

        except Exception as e:
            print("Error:", e)

run_button.on_click(run_query)



# ===============================================================
# 4. GRAPHING TAB WIDGETS
# ===============================================================

graph_table_dropdown = widgets.Dropdown(description="Table:", options=tables)
plot_type_dropdown = widgets.Dropdown(description="Plot:", options=["Bar", "Scatter"])
x_dropdown = widgets.Dropdown(description="X:")
y_dropdown = widgets.Dropdown(description="Y:")
best_fit_checkbox = widgets.Checkbox(description="Best Fit", value=True)
plot_button = widgets.Button(description="Plot", button_style="success")
plot_output = widgets.Output()



# ===============================================================
# Update graph dropdowns
# ===============================================================

def update_graph_columns(change):
    table = graph_table_dropdown.value
    try:
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        cols = df.columns.tolist()

        if plot_type_dropdown.value == "Scatter":
            numeric = df.select_dtypes(include=np.number).columns.tolist()
            x_dropdown.options = numeric
            y_dropdown.options = numeric
            best_fit_checkbox.layout.display = "flex"
        else:
            x_dropdown.options = cols
            y_dropdown.options = cols
            best_fit_checkbox.layout.display = "none"

    except:
        x_dropdown.options = []
        y_dropdown.options = []

graph_table_dropdown.observe(update_graph_columns, names="value")
plot_type_dropdown.observe(update_graph_columns, names="value")
update_graph_columns(None)



# ===============================================================
# MAKE PLOT
# ===============================================================

def make_plot(button):
    with plot_output:
        clear_output()

        table = graph_table_dropdown.value
        xcol = x_dropdown.value
        ycol = y_dropdown.value
        kind = plot_type_dropdown.value

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        # -----------------------------
        # ⭐ FILTER AGE 18–22 ONLY ⭐
        # -----------------------------
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            df = df.dropna(subset=["age"])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]

        if df.empty:
            print("No students age 18–22 found.")
            return

        plt.figure(figsize=(7, 4))

        if kind == "Scatter":
            df[xcol] = pd.to_numeric(df[xcol], errors="coerce")
            df[ycol] = pd.to_numeric(df[ycol], errors="coerce")
            df = df.dropna(subset=[xcol, ycol])

            plt.scatter(df[xcol], df[ycol])

            if best_fit_checkbox.value:
                m, b = np.polyfit(df[xcol], df[ycol], 1)
                plt.plot(df[xcol], m*df[xcol] + b)

        else:  # bar chart
            try:
                counts = df.groupby(xcol)[ycol].count()
                counts.plot(kind="bar")
            except:
                print("Cannot plot with selected columns.")
                return

        plt.title(f"{ycol} vs {xcol}")
        plt.xlabel(xcol)
        plt.ylabel(ycol)
        plt.tight_layout()
        plt.show()

plot_button.on_click(make_plot)



# ===============================================================
# 5. CREATE TABS
# ===============================================================

data_tab = widgets.VBox([
    table_dropdown,
    join_dropdown,
    column_dropdown,
    filter_column_dropdown,
    filter_operator_dropdown,
    filter_value_text,
    run_button,
    output_box
])

graph_tab = widgets.VBox([
    graph_table_dropdown,
    plot_type_dropdown,
    x_dropdown,
    y_dropdown,
    best_fit_checkbox,
    plot_button,
    plot_output
])

tabs = widgets.Tab()
tabs.children = [data_tab, graph_tab]
tabs.set_title(0, "Data Management")
tabs.set_title(1, "Graphing")

display(tabs)


Loaded SQL_Handler.py


Tab(children=(VBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'Psychological', …

In [5]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import GridspecLayout
from IPython.display import display, clear_output
import io
import numpy as np

# --- Import your custom SQL_Handler.py file ---
try:
    import SQL_Handler
    print("Successfully imported SQL_Handler.py")
except ImportError:
    print("ERROR: SQL_Handler.py not found.")

# --- 2. Database Helper Functions ---
DB_NAME = 'Dataset.db'

def get_table_names():
    """Fetches all table names from the database."""
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0].strip() for table in cursor.fetchall() if table[0] != 'sqlite_sequence']
        conn.close()
        return tables
    except Exception as e:
        print(f"Database error: {e}")
        return []

def get_column_names(table_name):
    """Fetches all column names for a specific table."""
    if not table_name:
        return []
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute(f'PRAGMA table_info("{table_name}")')
        columns = [col[1].strip() for col in cursor.fetchall()]
        conn.close()
        return columns
    except Exception as e:
        print(f"Database error: {e}")
        return []

# Get a list of tables to use in our dropdowns
all_table_names = get_table_names()

# ===============================================
# --- 3. CREATE "DATA MANAGEMENT" TAB WIDGETS ---
# ===============================================

manage_table_dropdown = widgets.Dropdown(
    description='Table:',
    options=all_table_names
)
manage_join_dropdown = widgets.Dropdown(
    description='Join Table:',
    options=["(None)"] + all_table_names,
    value="(None)"
)
manage_cols_dropdown = widgets.Dropdown(
    description='Column:',
    options=[],
    value=None,
    disabled=True 
)
manage_filter_col_dropdown = widgets.Dropdown(
    options=[],
    description='Filter Column:',
    disabled=True
)
manage_filter_op_dropdown = widgets.Dropdown(
    options=['=', '!=', '>', '<', '>=', '<=', 'LIKE'],
    description='Operator:',
    value='='
)
manage_filter_val_text = widgets.Text(
    description='Value:',
    placeholder="e.g., 20 or 'M'"
)
manage_button = widgets.Button(description='Run Query')
manage_output = widgets.Output()


def on_manage_table_change(change):
    """Called when the 'Table' dropdown changes."""
    table_name = change['new']
    if table_name:
        columns = get_column_names(table_name)
        
        # Update the Column dropdown
        manage_cols_dropdown.options = ["*"] + columns
        manage_cols_dropdown.value = "*"
        manage_cols_dropdown.disabled = False
        
        # Update the Filter dropdown
        manage_filter_col_dropdown.options = ["(None)"] + columns
        manage_filter_col_dropdown.value = "(None)"
        manage_filter_col_dropdown.disabled = False
    else:
        # Disable dropdowns if no table is selected
        manage_cols_dropdown.options = []
        manage_cols_dropdown.value = None
        manage_cols_dropdown.disabled = True
        
        manage_filter_col_dropdown.options = []
        manage_filter_col_dropdown.value = None
        manage_filter_col_dropdown.disabled = True


def on_manage_query_click(b):
    """Called when the 'Run Query' button is clicked."""
    with manage_output:
        clear_output(wait=True)
        
        table = manage_table_dropdown.value
        cols = manage_cols_dropdown.value
        
        join = manage_join_dropdown.value
        if join == "(None)":
            join = None

        # Build the "WHERE" condition string
        cond = None
        filter_col = manage_filter_col_dropdown.value
        
        if filter_col != "(None)":
            filter_op = manage_filter_op_dropdown.value
            filter_val = manage_filter_val_text.value
            
            # Add quotes for text (like 'M') but not for numbers (like 20)
            if not filter_val.replace('.','',1).isdigit():
                filter_val = f"'{filter_val}'"
                
            cond = f'"{filter_col}" {filter_op} {filter_val}'
            
        print(f"Querying: SELECT {cols} FROM {table}...")
        if join:
            print(f"Joining with: {join}")
        if cond:
            print(f"Condition: WHERE {cond}")
            
        try:
            # Run the query using our SQL_Handler.py file
            results = SQL_Handler.data_selection(table, cols, cond, join)
            
            if results:
                # Get column headers for the results table
                if cols == '*':
                    df_cols = get_column_names(table)
                else:
                    df_cols = [cols] 
                    
                if join:
                    df = pd.DataFrame(results)
                    print("Note: Column headers not available for JOIN queries.")
                else:
                    df = pd.DataFrame(results, columns=df_cols)
                
                print(f"\nSuccess! Found {len(results)} rows.")
                display(df) # Display the results as a table
            else:
                print("\nQuery executed, but returned no results.")
                
        except Exception as e:
            print(f"\nAn error occurred: {e}")

# Tell the dropdown to run our function when its value changes
manage_table_dropdown.observe(on_manage_table_change, names='value')

# Tell the button to run our function when it's clicked
manage_button.on_click(on_manage_query_click)

# This VBox holds the 3-part filter
manage_filter_box = widgets.VBox([
    manage_filter_col_dropdown,
    manage_filter_op_dropdown,
    manage_filter_val_text
], layout={'border': '1px solid #CCC', 'padding': '10px'})

# This VBox holds all the widgets for the first tab
data_management_tab = widgets.VBox([
    widgets.HBox([manage_table_dropdown, manage_join_dropdown]),
    widgets.HBox([manage_cols_dropdown, manage_filter_box]), 
    manage_button,
    manage_output
])

# Load the columns for the first table on startup
if all_table_names:
    on_manage_table_change({'new': all_table_names[0]})


# ===============================================
# --- 4. CREATE "DATA GRAPHING" TAB WIDGETS ---
# ===============================================

graph_table_dropdown = widgets.Dropdown(options=all_table_names, description="Table:")
graph_plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
graph_x_axis = widgets.Dropdown(description="X-Axis:")
graph_y_axis = widgets.Dropdown(description="Y-Axis:")
graph_best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
graph_generate_button = widgets.Button(description="Generate Plot", button_style="success")
graph_output = widgets.Output()

def graph_update_columns(change):
    """Called when the graphing 'Table' or 'Plot Type' dropdowns change."""
    table = graph_table_dropdown.value
    ptype = graph_plot_type.value
    
    conn = sqlite3.connect(DB_NAME)
    try:
        df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    except Exception as e:
        with graph_output:
            print(f"Error reading table {table}: {e}")
            graph_x_axis.options = []
            graph_y_axis.options = []
        return
    finally:
        conn.close()

    if ptype == "Scatter":
        # Scatter plots only make sense with numbers
        numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
        graph_x_axis.options = numeric_cols
        graph_y_axis.options = numeric_cols
        graph_best_fit.layout.display = 'flex' # Show checkbox
    else:
        # Bar plots can use any column
        all_cols = df.columns.tolist()
        graph_x_axis.options = all_cols
        graph_y_axis.options = all_cols
        graph_best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if graph_x_axis.options:
        graph_x_axis.value = graph_x_axis.options[0]
    if graph_y_axis.options and len(graph_y_axis.options) > 1:
        graph_y_axis.value = graph_y_axis.options[1]
    elif graph_y_axis.options:
        graph_y_axis.value = graph_y_axis.options[0]

# Tell the dropdowns to run our function when they change
graph_table_dropdown.observe(graph_update_columns, names="value")
graph_plot_type.observe(graph_update_columns, names="value")

def graph_generate_plot(b):
    """Called when the 'Generate Plot' button is clicked."""
    with graph_output:
        clear_output(wait=True)
        
        table = graph_table_dropdown.value
        x_col = graph_x_axis.value
        y_col = graph_y_axis.value
        kind = graph_plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        try:
            conn = sqlite3.connect(DB_NAME)
            df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
            conn.close()
        except Exception as e:
            print(f"Error querying database: {e}")
            return

        # --- Filter by Age (18-22) ---
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            original_count = len(df)
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            if 'age' in df.columns:
                df['age'] = df['age'].round()
            print(f"Filter applied: Kept {len(df)} of {original_count} students aged 18–22.")

        plt.figure(figsize=(7,4))
        
        # --- Scatter Plot ---
        if kind == "Scatter":
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) 

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close()
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if graph_best_fit.value and len(df) > 1:
                m, b = np.polyfit(df[x_col], df[y_col], 1)
                plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar Plot ---
        else:
            x_data = pd.to_numeric(df[x_col], errors='coerce')
            if x_data.isnull().all(): 
                x_data = df[x_col]
                
            y_data = pd.to_numeric(df[y_col], errors='coerce')
            if y_data.isnull().all():
                y_data = df[y_col]
                
            df[x_col] = x_data
            df[y_col] = y_data
            
            df = df.dropna(subset=[x_col, y_col])
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close()
                return
            
            # Helper function to decide if a column is "categorical" (text or few numbers)
            def is_col_categorical(col_name, unique_thresh=25):
                if pd.api.types.is_object_dtype(df[col_name]):
                    return True
                if pd.api.types.is_numeric_dtype(df[col_name]):
                    if df[col_name].nunique() < unique_thresh:
                        if 'id' not in str(col_name).lower():
                            return True
                return False

            is_x_categorical = is_col_categorical(x_col)
            is_y_categorical = is_col_categorical(y_col)
            
            # Case 1: Y-axis is categorical (e.g., X=age, Y=gender)
            # This makes a "Grouped Count" chart
            if is_y_categorical:
                grouped = pd.crosstab(df[x_col], df[y_col]) 
                grouped = grouped.sort_index()
                
                indices = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        indices + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count")

                labels = grouped.index.astype(str)
                positions = indices + bar_width*(len(grouped.columns)/2 - 0.5)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 2: X-axis is categorical, Y-axis is numeric (e.g., X=age, Y=student_id)
            # This makes a "Count" chart
            elif is_x_categorical and not is_y_categorical:
                agg_data = df.groupby(x_col)[y_col].count() 
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel("Number of Students") 

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 3: Both axes are numeric (and not categorical)
            # This makes a simple "Count" chart
            else:
                agg_data = df.groupby(x_col)[y_col].count()
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Count of {y_col}")

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show() # Display the plot

# Tell the button to run our function when it's clicked
graph_generate_button.on_click(graph_generate_plot)

# This VBox holds all the widgets for the second tab
data_graphing_tab = widgets.VBox([
    widgets.HBox([graph_table_dropdown, graph_plot_type]),
    widgets.HBox([graph_x_axis, graph_y_axis]),
    widgets.HBox([graph_generate_button, graph_best_fit]),
    graph_output
])

# Load the columns for the first table on startup
if all_table_names:
    graph_update_columns(None)


# ===============================================
# --- 5. ASSEMBLE AND DISPLAY THE TABS ---
# ===============================================

tab_container = widgets.Tab()
tab_container.children = [data_management_tab, data_graphing_tab]
tab_container.set_title(0, 'Data Management')
tab_container.set_title(1, 'Data Graphing')

# This is the final step: display the tabs!
display(tab_container)

Successfully imported SQL_Handler.py


Tab(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'P…

In [3]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import GridspecLayout
from IPython.display import display, clear_output
import io
import numpy as np

# --- Import your custom SQL_Handler.py file ---
try:
    import SQL_Handler
    print("Successfully imported SQL_Handler.py")
except ImportError:
    print("ERROR: SQL_Handler.py not found.")

# --- 2. Database Helper Functions ---
DB_NAME = 'Dataset.db'

def get_table_names():
    """Fetches all table names from the database."""
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0].strip() for table in cursor.fetchall() if table[0] != 'sqlite_sequence']
        conn.close()
        return tables
    except Exception as e:
        print(f"Database error: {e}")
        return []

def get_column_names(table_name):
    """Fetches all column names for a specific table."""
    if not table_name:
        return []
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute(f'PRAGMA table_info("{table_name}")')
        columns = [col[1].strip() for col in cursor.fetchall()]
        conn.close()
        return columns
    except Exception as e:
        print(f"Database error: {e}")
        return []

# Get a list of tables to use in our dropdowns
all_table_names = get_table_names()


# --- 3. CREATE "DATA MANAGEMENT" TAB WIDGETS ---
manage_table_dropdown = widgets.Dropdown(
    description='Table:',
    options=all_table_names
)
manage_join_dropdown = widgets.Dropdown(
    description='Join Table:',
    options=["(None)"] + all_table_names,
    value="(None)"
)
manage_cols_dropdown = widgets.Dropdown(
    description='Column:',
    options=[],
    value=None,
    disabled=True 
)
manage_filter_col_dropdown = widgets.Dropdown(
    options=[],
    description='Filter Column:',
    disabled=True
)
manage_filter_op_dropdown = widgets.Dropdown(
    options=['=', '!=', '>', '<', '>=', '<=', 'LIKE'],
    description='Operator:',
    value='='
)
manage_filter_val_text = widgets.Text(
    description='Value:',
    placeholder="e.g., 20 or 'M'"
)
manage_button = widgets.Button(description='Run Query')
manage_output = widgets.Output()


def on_manage_table_change(change):
    """Called when the 'Table' dropdown changes."""
    table_name = change['new']
    if table_name:
        columns = get_column_names(table_name)
        
        # Update the Column dropdown
        manage_cols_dropdown.options = ["*"] + columns
        manage_cols_dropdown.value = "*"
        manage_cols_dropdown.disabled = False
        
        # Update the Filter dropdown
        manage_filter_col_dropdown.options = ["(None)"] + columns
        manage_filter_col_dropdown.value = "(None)"
        manage_filter_col_dropdown.disabled = False
    else:
        # Disable dropdowns if no table is selected
        manage_cols_dropdown.options = []
        manage_cols_dropdown.value = None
        manage_cols_dropdown.disabled = True
        
        manage_filter_col_dropdown.options = []
        manage_filter_col_dropdown.value = None
        manage_filter_col_dropdown.disabled = True


def on_manage_query_click(b):
    """Called when the 'Run Query' button is clicked."""
    with manage_output:
        clear_output(wait=True)
        
        table = manage_table_dropdown.value
        cols = manage_cols_dropdown.value
        
        join = manage_join_dropdown.value
        if join == "(None)":
            join = None

        # Build the "WHERE" condition string
        cond = None
        filter_col = manage_filter_col_dropdown.value
        
        if filter_col != "(None)":
            filter_op = manage_filter_op_dropdown.value
            filter_val = manage_filter_val_text.value
            
            # Add quotes for text (like 'M') but not for numbers (like 20)
            if not filter_val.replace('.','',1).isdigit():
                filter_val = f"'{filter_val}'"
                
            cond = f'"{filter_col}" {filter_op} {filter_val}'
            
        print(f"Querying: SELECT {cols} FROM {table}...")
        if join:
            print(f"Joining with: {join}")
        if cond:
            print(f"Condition: WHERE {cond}")
            
        try:
            # Run the query using our SQL_Handler.py file
            results = SQL_Handler.data_selection(table, cols, cond, join)
            
            if results:
                # Get column headers for the results table
                if cols == '*':
                    df_cols = get_column_names(table)
                else:
                    df_cols = [cols] 
                    
                if join:
                    df = pd.DataFrame(results)
                    print("Note: Column headers not available for JOIN queries.")
                else:
                    df = pd.DataFrame(results, columns=df_cols)
                
                print(f"\nSuccess! Found {len(results)} rows.")
                display(df) # Display the results as a table
            else:
                print("\nQuery executed, but returned no results.")
                
        except Exception as e:
            print(f"\nAn error occurred: {e}")

# Tell the dropdown to run our function when its value changes
manage_table_dropdown.observe(on_manage_table_change, names='value')

# Tell the button to run our function when it's clicked
manage_button.on_click(on_manage_query_click)

# This VBox holds the 3-part filter
manage_filter_box = widgets.VBox([
    manage_filter_col_dropdown,
    manage_filter_op_dropdown,
    manage_filter_val_text
], layout={'border': '1px solid #CCC', 'padding': '10px'})

# This VBox holds all the widgets for the first tab
data_management_tab = widgets.VBox([
    widgets.HBox([manage_table_dropdown, manage_join_dropdown]),
    widgets.HBox([manage_cols_dropdown, manage_filter_box]), 
    manage_button,
    manage_output
])

# Load the columns for the first table on startup
if all_table_names:
    on_manage_table_change({'new': all_table_names[0]})


# --- 4. CREATE "DATA GRAPHING" TAB WIDGETS ---
graph_table_dropdown = widgets.Dropdown(options=all_table_names, description="Table:")
graph_plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
graph_x_axis = widgets.Dropdown(description="X-Axis:")
graph_y_axis = widgets.Dropdown(description="Y-Axis:")
graph_best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
graph_generate_button = widgets.Button(description="Generate Plot", button_style="success")
graph_output = widgets.Output()

def graph_update_columns(change):
    """Called when the graphing 'Table' or 'Plot Type' dropdowns change."""
    table = graph_table_dropdown.value
    ptype = graph_plot_type.value
    
    conn = sqlite3.connect(DB_NAME)
    try:
        df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    except Exception as e:
        with graph_output:
            print(f"Error reading table {table}: {e}")
            graph_x_axis.options = []
            graph_y_axis.options = []
        return
    finally:
        conn.close()

    if ptype == "Scatter":
        # Scatter plots only make sense with numbers
        numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
        graph_x_axis.options = numeric_cols
        graph_y_axis.options = numeric_cols
        graph_best_fit.layout.display = 'flex' # Show checkbox
    else:
        # Bar plots can use any column
        all_cols = df.columns.tolist()
        graph_x_axis.options = all_cols
        graph_y_axis.options = all_cols
        graph_best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if graph_x_axis.options:
        graph_x_axis.value = graph_x_axis.options[0]
    if graph_y_axis.options and len(graph_y_axis.options) > 1:
        graph_y_axis.value = graph_y_axis.options[1]
    elif graph_y_axis.options:
        graph_y_axis.value = graph_y_axis.options[0]

# Tell the dropdowns to run our function when they change
graph_table_dropdown.observe(graph_update_columns, names="value")
graph_plot_type.observe(graph_update_columns, names="value")

def graph_generate_plot(b):
    """Called when the 'Generate Plot' button is clicked."""
    with graph_output:
        clear_output(wait=True)
        
        table = graph_table_dropdown.value
        x_col = graph_x_axis.value
        y_col = graph_y_axis.value
        kind = graph_plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        try:
            conn = sqlite3.connect(DB_NAME)
            df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
            conn.close()
        except Exception as e:
            print(f"Error querying database: {e}")
            return

        # --- Filter by Age (18-22) ---
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            original_count = len(df)
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            if 'age' in df.columns:
                df['age'] = df['age'].round()
            print(f"Filter applied: Kept {len(df)} of {original_count} students aged 18–22.")

        plt.figure(figsize=(7,4))
        
        # --- Scatter Plot ---
        if kind == "Scatter":
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) 

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close()
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if graph_best_fit.value and len(df) > 1:
                m, b = np.polyfit(df[x_col], df[y_col], 1)
                plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar Plot ---
        else:
            x_data = pd.to_numeric(df[x_col], errors='coerce')
            if x_data.isnull().all(): 
                x_data = df[x_col]
                
            y_data = pd.to_numeric(df[y_col], errors='coerce')
            if y_data.isnull().all():
                y_data = df[y_col]
                
            df[x_col] = x_data
            df[y_col] = y_data
            
            df = df.dropna(subset=[x_col, y_col])
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close()
                return
            
            # Helper function to decide if a column is "categorical" (text or few numbers)
            def is_col_categorical(col_name, unique_thresh=25):
                if pd.api.types.is_object_dtype(df[col_name]):
                    return True
                if pd.api.types.is_numeric_dtype(df[col_name]):
                    if df[col_name].nunique() < unique_thresh:
                        if 'id' not in str(col_name).lower():
                            return True
                return False

            is_x_categorical = is_col_categorical(x_col)
            is_y_categorical = is_col_categorical(y_col)
            
            # Case 1: Y-axis is categorical (e.g., X=age, Y=gender)
            if is_y_categorical:
                grouped = pd.crosstab(df[x_col], df[y_col]) 
                grouped = grouped.sort_index()
                
                indices = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        indices + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count")

                labels = grouped.index.astype(str)
                positions = indices + bar_width*(len(grouped.columns)/2 - 0.5)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 2: X-axis is categorical, Y-axis is numeric (e.g., X=age, Y=student_id)
            elif is_x_categorical and not is_y_categorical:
                agg_data = df.groupby(x_col)[y_col].count() 
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel("Number of Students") 

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 3: Both axes are numeric (and not categorical)
            else:
                agg_data = df.groupby(x_col)[y_col].count()
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Count of {y_col}")

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show() # Display the plot

# Tell the button to run our function when it's clicked
graph_generate_button.on_click(graph_generate_plot)

# This VBox holds all the widgets for the second tab
data_graphing_tab = widgets.VBox([
    widgets.HBox([graph_table_dropdown, graph_plot_type]),
    widgets.HBox([graph_x_axis, graph_y_axis]),
    widgets.HBox([graph_generate_button, graph_best_fit]),
    graph_output
])

# Load the columns for the first table on startup
if all_table_names:
    graph_update_columns(None)


# --- 5. ASSEMBLE AND DISPLAY THE TABS ---
tab_container = widgets.Tab()
tab_container.children = [data_management_tab, data_graphing_tab]
tab_container.set_title(0, 'Data Management')
tab_container.set_title(1, 'Data Graphing')

# This is the final step: display the tabs!
display(tab_container)

Successfully imported SQL_Handler.py


Tab(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'P…

In [1]:
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

# Connect to database
conn = sqlite3.connect("students.db")

# Load a table and filter age if present
def load_table(table_name):
    df = pd.read_sql(f"SELECT * FROM {table_name}", conn)
    
    # Age filter 18–22 only
    if "age" in df.columns:
        df = df[(df["age"] >= 18) & (df["age"] <= 22)]
    
    return df


# -------------------------
# Data Management Tab
# -------------------------
table_dropdown_dm = widgets.Dropdown(description="Table:")
output_dm = widgets.Output()
refresh_button = widgets.Button(description="Refresh Tables")

def refresh_tables(_):
    cur = conn.cursor()
    cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
    table_dropdown_dm.options = [row[0] for row in cur.fetchall()]

def show_table(change):
    with output_dm:
        output_dm.clear_output()
        if table_dropdown_dm.value:
            df = load_table(table_dropdown_dm.value)
            display(df)

refresh_button.on_click(refresh_tables)
table_dropdown_dm.observe(show_table, names="value")

data_tab = widgets.VBox([
    widgets.HBox([refresh_button, table_dropdown_dm]),
    output_dm
])


# -------------------------
# Graphing Tab
# -------------------------
plot_dropdown = widgets.Dropdown(
    options=["Bar", "Line", "Scatter"],
    description="Plot:"
)

table_dropdown_graph = widgets.Dropdown(description="Table:")
x_dropdown = widgets.Dropdown(description="X:")
y_dropdown = widgets.Dropdown(description="Y:")
plot_button = widgets.Button(description="Plot", button_style="success")

output_graph = widgets.Output()
output_stats = widgets.Output()

def load_columns(change):
    table = table_dropdown_graph.value
    if table:
        df = load_table(table)
        cols = list(df.columns)
        x_dropdown.options = cols
        y_dropdown.options = cols

table_dropdown_graph.observe(load_columns, names="value")


def show_stats(df, col):
    """Shows basic stats for a numeric column."""
    series = df[col]

    if series.dtype == object:
        return  # No stats for letters/words

    print(f"Statistics for {col}:")
    print("Min:", series.min())
    print("Max:", series.max())
    print("Average:", round(series.mean(), 2))
    print("Count:", series.count())


def plot_graph(_):
    with output_graph:
        output_graph.clear_output()
    with output_stats:
        output_stats.clear_output()

    table = table_dropdown_graph.value
    x_col = x_dropdown.value
    y_col = y_dropdown.value

    df = load_table(table)

    # -------- Make graph --------
    with output_graph:
        plt.figure(figsize=(8,5))

        if df[y_col].dtype == object:
            # Group and count for category plots
            grouped = df.groupby([x_col, y_col]).size().unstack(fill_value=0)
            grouped.plot(kind="bar")
            plt.ylabel("Count")
            stats_col = x_col  # numeric column
        else:
            df.plot(kind=plot_dropdown.value.lower(), x=x_col, y=y_col)
            stats_col = y_col

        plt.title(f"{y_col} vs {x_col}")
        plt.tight_layout()
        plt.show()

    # -------- Show simple stats --------
    with output_stats:
        show_stats(df, stats_col)


graph_tab = widgets.VBox([
    widgets.HBox([table_dropdown_graph, plot_dropdown]),
    widgets.HBox([x_dropdown, y_dropdown]),
    plot_button,
    output_graph,
    output_stats
])


# -------------------------
# Tabs
# -------------------------
tabs = widgets.Tab()
tabs.children = [data_tab, graph_tab]
tabs.set_title(0, "Data Management")
tabs.set_title(1, "Graphing")
display(tabs)

refresh_tables(None)


Tab(children=(VBox(children=(HBox(children=(Button(description='Refresh Tables', style=ButtonStyle()), Dropdow…

In [2]:
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

# ---------------------------------------------------
# CONNECT TO DATABASE
# ---------------------------------------------------
conn = sqlite3.connect("students.db")

# Load table and filter age
def load_table(table_name):
    df = pd.read_sql(f"SELECT * FROM {table_name}", conn)

    # If table has an age column, filter 18–22
    if "age" in df.columns:
        df = df[(df["age"] >= 18) & (df["age"] <= 22)]

    return df


# ---------------------------------------------------
# DATA MANAGEMENT TAB
# ---------------------------------------------------

table_dropdown_dm = widgets.Dropdown(description="Table:")
output_dm = widgets.Output()

refresh_button = widgets.Button(description="Refresh Tables")

def refresh_tables(_):
    cur = conn.cursor()
    cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = [row[0] for row in cur.fetchall()]
    table_dropdown_dm.options = tables

def show_table(change):
    with output_dm:
        output_dm.clear_output()
        if table_dropdown_dm.value:
            df = load_table(table_dropdown_dm.value)
            display(df)

refresh_button.on_click(refresh_tables)
table_dropdown_dm.observe(show_table, names="value")

data_tab = widgets.VBox([
    widgets.HBox([refresh_button, table_dropdown_dm]),
    output_dm
])


# ---------------------------------------------------
# GRAPHING TAB
# ---------------------------------------------------

plot_dropdown = widgets.Dropdown(
    options=["Bar", "Line", "Scatter"],
    description="Plot:"
)

table_dropdown_graph = widgets.Dropdown(description="Table:")
x_dropdown = widgets.Dropdown(description="X:")
y_dropdown = widgets.Dropdown(description="Y:")
plot_button = widgets.Button(description="Plot", button_style="success")

output_graph = widgets.Output()
output_stats = widgets.Output()

def load_columns(change):
    table = table_dropdown_graph.value
    if table:
        df = load_table(table)
        x_dropdown.options = df.columns
        y_dropdown.options = df.columns

table_dropdown_graph.observe(load_columns, names="value")


# Show statistics for numeric data
def show_stats(df, col):
    series = df[col]

    # Skip categories (like gender, major, etc.)
    if series.dtype == object:
        print(f"No numeric stats for '{col}' because it is categorical.")
        return

    print(f"Statistics for {col}:")
    print("Min:", series.min())
    print("Max:", series.max())
    print("Average:", round(series.mean(), 2))
    print("Count:", series.count())


def plot_graph(_):
    with output_graph:
        output_graph.clear_output()
    with output_stats:
        output_stats.clear_output()

    table = table_dropdown_graph.value
    x_col = x_dropdown.value
    y_col = y_dropdown.value

    df = load_table(table)

    # -----------------------------
    #         MAKE GRAPH
    # -----------------------------
    with output_graph:
        plt.figure(figsize=(8,5))

        # If Y is text (categorical), show counts
        if df[y_col].dtype == object:
            grouped = df.groupby([x_col, y_col]).size().unstack(fill_value=0)
            grouped.plot(kind="bar")
            plt.ylabel("Count")
            stats_col = x_col  # numeric axis

        else:
            df.plot(kind=plot_dropdown.value.lower(), x=x_col, y=y_col)
            stats_col = y_col  # numeric axis

        plt.title(f"{y_col} vs {x_col}")
        plt.tight_layout()
        plt.show()

    # -----------------------------
    #        SHOW STATS
    # -----------------------------
    with output_stats:
        show_stats(df, stats_col)


graph_tab = widgets.VBox([
    widgets.HBox([table_dropdown_graph, plot_dropdown]),
    widgets.HBox([x_dropdown, y_dropdown]),
    plot_button,
    output_graph,
    output_stats
])


# ---------------------------------------------------
# MAIN TABS
# ---------------------------------------------------
tabs = widgets.Tab()
tabs.children = [data_tab, graph_tab]
tabs.set_title(0, "Data Management")
tabs.set_title(1, "Graphing")

display(tabs)

# Load tables on startup
refresh_tables(None)


Tab(children=(VBox(children=(HBox(children=(Button(description='Refresh Tables', style=ButtonStyle()), Dropdow…

In [5]:
#import streamlit as st
import sqlite3
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

DB_NAME = "Dataset.db"

# ----------------------------------
# Helper functions
# ----------------------------------

def get_table_names():
    try:
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(
            "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';",
            conn
        )
        conn.close()
        return df["name"].tolist()
    except:
        return []

def get_column_names(table):
    try:
        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f'PRAGMA table_info("{table}")', conn)
        conn.close()
        return df["name"].tolist()
    except:
        return []

def load_table(table):
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
    conn.close()
    return df

# ----------------------------------
# Streamlit UI
# ----------------------------------

st.title("📊 Database Graphing Tool (Streamlit Version)")
st.write("Simple + beginner friendly interface")

tables = get_table_names()

if not tables:
    st.error("No tables found in Database!")
    st.stop()

# --- Select Table ---
table = st.selectbox("Select Table", tables)

df = load_table(table)
columns = df.columns.tolist()

# --- Select Plot Type ---
plot_type = st.selectbox("Plot Type", ["Bar", "Scatter"])

# --- X & Y selection ---
x_col = st.selectbox("X-axis", columns)
y_col = st.selectbox("Y-axis", columns)

# --- Age filter automatically applied ---
if "age" in df.columns:
    df["age"] = pd.to_numeric(df["age"], errors="coerce")
    original = len(df)
    df = df.dropna(subset=["age"])
    df = df[(df["age"] >= 18) & (df["age"] <= 22)]
    st.info(f"Age filter active: Showing {len(df)} of {original} rows (age 18–22).")

# ----------------------------------
# Generate Plot Button
# ----------------------------------

if st.button("Generate Plot"):
    
    # Convert numeric where possible
    df[x_col] = pd.to_numeric(df[x_col], errors="ignore")
    df[y_col] = pd.to_numeric(df[y_col], errors="ignore")

    st.subheader("📈 Generated Plot")

    fig, ax = plt.subplots(figsize=(7,4))

    # ----------------------------------
    # SCATTER PLOT
    # ----------------------------------
    if plot_type == "Scatter":
        # Force numeric for scatter
        df = df.dropna(subset=[x_col, y_col])
        try:
            x = pd.to_numeric(df[x_col], errors="coerce")
            y = pd.to_numeric(df[y_col], errors="coerce")
            df = df.dropna(subset=[x_col, y_col])
            ax.scatter(df[x_col], df[y_col], alpha=0.7)

            # Best fit line
            if len(df) > 1:
                m, b = np.polyfit(df[x_col], df[y_col], 1)
                ax.plot(df[x_col], m*df[x_col] + b, linewidth=2)
        except:
            st.error("Both X and Y must be numeric for scatter.")
            st.stop()

    # ----------------------------------
    # BAR PLOT
    # ----------------------------------
    else:
        # if x_col is categorical
        try:
            ax.bar(df[x_col].astype(str), df[y_col])
            plt.xticks(rotation=45, ha="right")
        except:
            st.error("Cannot create bar plot with selected columns.")
            st.stop()

    ax.set_title(f"{y_col} vs {x_col}")
    ax.set_xlabel(x_col)
    ax.grid(alpha=0.3)
    st.pyplot(fig)

    # ----------------------------------
    # STATISTICS BOX
    # ----------------------------------
    st.subheader("📊 Statistics (Based on X-axis Column)")

    try:
        numeric_series = pd.to_numeric(df[x_col], errors="coerce").dropna()

        if numeric_series.empty:
            st.info("X-axis is not numeric — No numeric stats available.")
        else:
            st.write(f"**Column:** `{x_col}`")
            st.write(f"- **Count:** {len(numeric_series)}")
            st.write(f"- **Min:** {numeric_series.min()}")
            st.write(f"- **Max:** {numeric_series.max()}")
            st.write(f"- **Mean:** {numeric_series.mean():.2f}")
            st.write(f"- **Median:** {numeric_series.median():.2f}")

    except:
        st.info("No numeric statistics available.")


NameError: name 'st' is not defined

In [6]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import GridspecLayout
from IPython.display import display, clear_output
import io
import numpy as np

# --- Import your custom SQL_Handler.py file ---
try:
    import SQL_Handler
    print("Successfully imported SQL_Handler.py")
except ImportError:
    print("ERROR: SQL_Handler.py not found.")

# --- 2. Database Helper Functions ---
DB_NAME = 'Dataset.db'

def get_table_names():
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0].strip() for table in cursor.fetchall() if table[0] != 'sqlite_sequence']
        conn.close()
        return tables
    except Exception as e:
        print(f"Database error: {e}")
        return []

def get_column_names(table_name):
    if not table_name: return []
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute(f'PRAGMA table_info("{table_name}")')
        columns = [col[1].strip() for col in cursor.fetchall()]
        conn.close()
        return columns
    except Exception as e:
        print(f"Database error: {e}")
        return []

all_table_names = get_table_names()

# --- 3. DATA MANAGEMENT TAB (NO CHANGES NEEDED) ---
manage_table_dropdown = widgets.Dropdown(description='Table:', options=all_table_names)
manage_join_dropdown = widgets.Dropdown(description='Join Table:', options=["(None)"] + all_table_names, value="(None)")
manage_cols_dropdown = widgets.Dropdown(description='Column:', options=[], value=None, disabled=True)
manage_filter_col_dropdown = widgets.Dropdown(options=[], description='Filter Column:', disabled=True)
manage_filter_op_dropdown = widgets.Dropdown(options=['=', '!=', '>', '<', '>=', '<=', 'LIKE'], description='Operator:', value='=')
manage_filter_val_text = widgets.Text(description='Value:', placeholder="e.g., 20 or 'M'")
manage_button = widgets.Button(description='Run Query')
manage_output = widgets.Output()

def on_manage_table_change(change):
    table_name = change['new']
    if table_name:
        columns = get_column_names(table_name)
        manage_cols_dropdown.options = ["*"] + columns
        manage_cols_dropdown.value = "*"
        manage_cols_dropdown.disabled = False
        manage_filter_col_dropdown.options = ["(None)"] + columns
        manage_filter_col_dropdown.value = "(None)"
        manage_filter_col_dropdown.disabled = False
    else:
        manage_cols_dropdown.options = []
        manage_cols_dropdown.value = None
        manage_cols_dropdown.disabled = True
        manage_filter_col_dropdown.options = []
        manage_filter_col_dropdown.value = None
        manage_filter_col_dropdown.disabled = True

def on_manage_query_click(b):
    with manage_output:
        clear_output(wait=True)
        
        table = manage_table_dropdown.value
        cols = manage_cols_dropdown.value
        join = manage_join_dropdown.value
        if join == "(None)": join = None

        cond = None
        filter_col = manage_filter_col_dropdown.value
        
        if filter_col != "(None)":
            filter_op = manage_filter_op_dropdown.value
            filter_val = manage_filter_val_text.value
            if not filter_val.replace('.','',1).isdigit():
                filter_val = f"'{filter_val}'"
            cond = f'"{filter_col}" {filter_op} {filter_val}'
            
        try:
            results = SQL_Handler.data_selection(table, cols, cond, join)
            if results:
                if cols == '*':
                    df_cols = get_column_names(table)
                else:
                    df_cols = [cols]
                if join:
                    df = pd.DataFrame(results)
                else:
                    df = pd.DataFrame(results, columns=df_cols)
                display(df)
            else:
                print("\nQuery executed, but returned no results.")
        except Exception as e:
            print(f"\nAn error occurred: {e}")

manage_table_dropdown.observe(on_manage_table_change, names='value')
manage_button.on_click(on_manage_query_click)

manage_filter_box = widgets.VBox([
    manage_filter_col_dropdown,
    manage_filter_op_dropdown,
    manage_filter_val_text
], layout={'border': '1px solid #CCC', 'padding': '10px'})

data_management_tab = widgets.VBox([
    widgets.HBox([manage_table_dropdown, manage_join_dropdown]),
    widgets.HBox([manage_cols_dropdown, manage_filter_box]),
    manage_button,
    manage_output
])

if all_table_names:
    on_manage_table_change({'new': all_table_names[0]})


# --- 4. DATA GRAPHING TAB WITH STATS ADDED ---
graph_table_dropdown = widgets.Dropdown(options=all_table_names, description="Table:")
graph_plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
graph_x_axis = widgets.Dropdown(description="X-Axis:")
graph_y_axis = widgets.Dropdown(description="Y-Axis:")
graph_best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
graph_generate_button = widgets.Button(description="Generate Plot", button_style="success")
graph_output = widgets.Output()

### NEW: STATS AREA ###
graph_stats_output = widgets.Output()

def graph_update_columns(change):
    table = graph_table_dropdown.value
    ptype = graph_plot_type.value
    
    conn = sqlite3.connect(DB_NAME)
    try:
        df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    except:
        graph_x_axis.options = []
        graph_y_axis.options = []
        return
    finally:
        conn.close()

    if ptype == "Scatter":
        numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
        graph_x_axis.options = numeric_cols
        graph_y_axis.options = numeric_cols
        graph_best_fit.layout.display = 'flex'
    else:
        all_cols = df.columns.tolist()
        graph_x_axis.options = all_cols
        graph_y_axis.options = all_cols
        graph_best_fit.layout.display = 'none'

    if graph_x_axis.options:
        graph_x_axis.value = graph_x_axis.options[0]
    if graph_y_axis.options:
        graph_y_axis.value = graph_y_axis.options[-1]

graph_table_dropdown.observe(graph_update_columns, names="value")
graph_plot_type.observe(graph_update_columns, names="value")

def graph_generate_plot(b):
    with graph_output:
        clear_output(wait=True)
    with graph_stats_output:
        clear_output(wait=True)

    table = graph_table_dropdown.value
    x_col = graph_x_axis.value
    y_col = graph_y_axis.value
    kind = graph_plot_type.value
    
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
    conn.close()

    # Filter by age 18–22
    if "age" in df.columns:
        df["age"] = pd.to_numeric(df["age"], errors="coerce")
        df = df.dropna(subset=["age"])
        df = df[(df["age"] >= 18) & (df["age"] <= 22)]
        df["age"] = df["age"].round()

    plt.figure(figsize=(7,4))

    # --- PLOTTING (NO CHANGE) ---
    if kind == "Scatter":
        df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
        df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
        df = df.dropna(subset=[x_col, y_col])
        plt.scatter(df[x_col], df[y_col], alpha=0.7)

        if graph_best_fit.value and len(df) > 1:
            m, b = np.polyfit(df[x_col], df[y_col], 1)
            plt.plot(df[x_col], m*df[x_col] + b)

    else:  
        df[x_col] = pd.to_numeric(df[x_col], errors="ignore")
        df[y_col] = pd.to_numeric(df[y_col], errors="ignore")
        df = df.dropna(subset=[x_col, y_col])
        counts = df.groupby(x_col)[y_col].count()
        positions = np.arange(len(counts))
        plt.bar(positions, counts.values)
        plt.xticks(positions, counts.index.astype(str), rotation=45, ha="right")

    plt.title(f"{y_col} vs {x_col} ({kind})")
    plt.xlabel(x_col)
    plt.grid(alpha=0.3)
    plt.tight_layout()

    with graph_output:
        plt.show()

    ### NEW: DISPLAY STATS ###
    with graph_stats_output:
        print("📊 DATA STATISTICS")
        print("---------------------")

        # Only show stats for numeric columns
        for col in [x_col, y_col]:
            try:
                numeric_series = pd.to_numeric(df[col], errors='coerce').dropna()
                if len(numeric_series) > 0:
                    print(f"\n➡ {col}:")
                    print(f"   • Min:   {numeric_series.min()}")
                    print(f"   • Max:   {numeric_series.max()}")
                    print(f"   • Mean:  {round(numeric_series.mean(), 2)}")
                    print(f"   • Count: {numeric_series.count()}")
            except:
                pass

data_graphing_tab = widgets.VBox([
    widgets.HBox([graph_table_dropdown, graph_plot_type]),
    widgets.HBox([graph_x_axis, graph_y_axis]),
    widgets.HBox([graph_generate_button, graph_best_fit]),
    graph_output,
    graph_stats_output   ### STATS BOX SHOWS HERE ###
])

if all_table_names:
    graph_update_columns(None)


# --- 5. ASSEMBLE TABS ---
tab_container = widgets.Tab()
tab_container.children = [data_management_tab, data_graphing_tab]
tab_container.set_title(0, 'Data Management')
tab_container.set_title(1, 'Data Graphing')
display(tab_container)


Successfully imported SQL_Handler.py


Tab(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'P…

In [7]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import GridspecLayout
from IPython.display import display, clear_output
import io
import numpy as np

# --- Import your custom SQL_Handler.py file ---
try:
    import SQL_Handler
    print("Successfully imported SQL_Handler.py")
except ImportError:
    print("ERROR: SQL_Handler.py not found.")

# --- 2. Database Helper Functions ---
DB_NAME = 'Dataset.db'

def get_table_names():
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [t[0] for t in cursor.fetchall() if t[0] != 'sqlite_sequence']
        conn.close()
        return tables
    except:
        return []

def get_column_names(table):
    if not table:
        return []
    conn = sqlite3.connect(DB_NAME)
    cursor = conn.cursor()
    cursor.execute(f'PRAGMA table_info("{table}")')
    cols = [c[1] for c in cursor.fetchall()]
    conn.close()
    return cols

all_table_names = get_table_names()


# ===============================
# --- DATA MANAGEMENT TAB ---
# ===============================

manage_table_dropdown = widgets.Dropdown(description='Table:', options=all_table_names)
manage_join_dropdown = widgets.Dropdown(description='Join Table:', options=["(None)"] + all_table_names)
manage_cols_dropdown = widgets.Dropdown(description='Column:', options=[], disabled=True)
manage_filter_col_dropdown = widgets.Dropdown(description='Filter Column:', options=[], disabled=True)
manage_filter_op_dropdown = widgets.Dropdown(description='Operator:', options=['=', '!=', '>', '<', '>=', '<=', 'LIKE'])
manage_filter_val_text = widgets.Text(description='Value:')
manage_button = widgets.Button(description="Run Query")
manage_output = widgets.Output()

def on_manage_table_change(change):
    table = change["new"]
    columns = get_column_names(table)

    manage_cols_dropdown.options = ["*"] + columns
    manage_cols_dropdown.value = "*"
    manage_cols_dropdown.disabled = False

    manage_filter_col_dropdown.options = ["(None)"] + columns
    manage_filter_col_dropdown.value = "(None)"
    manage_filter_col_dropdown.disabled = False

manage_table_dropdown.observe(on_manage_table_change, names="value")

def on_manage_query_click(b):
    with manage_output:
        clear_output()

        table = manage_table_dropdown.value
        cols = manage_cols_dropdown.value
        join = manage_join_dropdown.value if manage_join_dropdown.value != "(None)" else None

        cond = None
        fc = manage_filter_col_dropdown.value
        if fc != "(None)":
            fv = manage_filter_val_text.value
            if not fv.replace(".", "", 1).isdigit():
                fv = f"'{fv}'"
            cond = f"{fc} {manage_filter_op_dropdown.value} {fv}"

        results = SQL_Handler.data_selection(table, cols, cond, join)

        if results:
            if cols == "*":
                df_cols = get_column_names(table)
            else:
                df_cols = [cols]
            df = pd.DataFrame(results, columns=df_cols)
            display(df)
        else:
            print("No results.")

manage_button.on_click(on_manage_query_click)

data_management_tab = widgets.VBox([
    widgets.HBox([manage_table_dropdown, manage_join_dropdown]),
    widgets.HBox([manage_cols_dropdown, manage_filter_col_dropdown, manage_filter_op_dropdown, manage_filter_val_text]),
    manage_button,
    manage_output
])

if all_table_names:
    on_manage_table_change({'new': all_table_names[0]})


# ===============================
# --- DATA GRAPHING TAB ---
# ===============================

graph_table_dropdown = widgets.Dropdown(options=all_table_names, description="Table:")
graph_plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
graph_x_axis = widgets.Dropdown(description="X-Axis:")
graph_y_axis = widgets.Dropdown(description="Y-Axis:")
graph_best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
graph_generate_button = widgets.Button(description="Generate Plot", button_style='success')

graph_output = widgets.Output()
graph_stats_output = widgets.Output()


def graph_update_columns(change):
    table = graph_table_dropdown.value
    conn = sqlite3.connect(DB_NAME)
    df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    conn.close()

    p = graph_plot_type.value
    if p == "Scatter":
        numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
        graph_x_axis.options = numeric_cols
        graph_y_axis.options = numeric_cols
    else:
        graph_x_axis.options = df.columns.tolist()
        graph_y_axis.options = df.columns.tolist()

    if graph_x_axis.options:
        graph_x_axis.value = graph_x_axis.options[0]
    if len(graph_y_axis.options) > 1:
        graph_y_axis.value = graph_y_axis.options[1]


graph_table_dropdown.observe(graph_update_columns, names="value")
graph_plot_type.observe(graph_update_columns, names="value")


def graph_generate_plot(b):
    with graph_output:
        clear_output()

        table = graph_table_dropdown.value
        xcol = graph_x_axis.value
        ycol = graph_y_axis.value
        kind = graph_plot_type.value

        conn = sqlite3.connect(DB_NAME)
        df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
        conn.close()

        plt.figure(figsize=(7,4))

        if kind == "Scatter":
            df[xcol] = pd.to_numeric(df[xcol], errors="coerce")
            df[ycol] = pd.to_numeric(df[ycol], errors="coerce")
            df = df.dropna(subset=[xcol, ycol])

            plt.scatter(df[xcol], df[ycol])
            if graph_best_fit.value:
                if len(df) > 1:
                    m, b = np.polyfit(df[xcol], df[ycol], 1)
                    plt.plot(df[xcol], m*df[xcol]+b, color='red')

        else:  # Bar
            df = df.dropna(subset=[xcol])
            counts = df[xcol].value_counts().sort_index()
            plt.bar(counts.index.astype(str), counts.values)

        plt.title(f"{ycol} vs {xcol}")
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()

    # --- SHOW STATS ---
    with graph_stats_output:
        clear_output()

        numeric = pd.to_numeric(df[xcol], errors="coerce").dropna()
        if numeric.empty:
            print("No numeric stats available.")
            return

        print("📊 **Data Statistics**")
        print(f"Min:   {numeric.min()}")
        print(f"Max:   {numeric.max()}")
        print(f"Mean:  {round(numeric.mean(), 2)}")
        print(f"Count: {numeric.count()}")


# VERY IMPORTANT — THIS MAKES THE BUTTON WORK
graph_generate_button.on_click(graph_generate_plot)


data_graphing_tab = widgets.VBox([
    widgets.HBox([graph_table_dropdown, graph_plot_type]),
    widgets.HBox([graph_x_axis, graph_y_axis]),
    widgets.HBox([graph_generate_button, graph_best_fit]),
    graph_output,
    graph_stats_output   # statistics box
])

if all_table_names:
    graph_update_columns(None)


# ===============================
# --- SHOW THE TABS ---
# ===============================

tab_container = widgets.Tab()
tab_container.children = [data_management_tab, data_graphing_tab]
tab_container.set_title(0, "Data Management")
tab_container.set_title(1, "Data Graphing")
display(tab_container)


Successfully imported SQL_Handler.py


Tab(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'P…

In [8]:
# --- 1. Import Libraries ---
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from ipywidgets import GridspecLayout
from IPython.display import display, clear_output
import io
import numpy as np

# --- Import your custom SQL_Handler.py file ---
try:
    import SQL_Handler
    print("Successfully imported SQL_Handler.py")
except ImportError:
    print("ERROR: SQL_Handler.py not found.")

# --- 2. Database Helper Functions ---
DB_NAME = 'Dataset.db'

def get_table_names():
    """Fetches all table names from the database."""
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [table[0].strip() for table in cursor.fetchall() if table[0] != 'sqlite_sequence']
        conn.close()
        return tables
    except Exception as e:
        print(f"Database error: {e}")
        return []

def get_column_names(table_name):
    """Fetches all column names for a specific table."""
    if not table_name:
        return []
    try:
        conn = sqlite3.connect(DB_NAME)
        cursor = conn.cursor()
        cursor.execute(f'PRAGMA table_info("{table_name}")')
        columns = [col[1].strip() for col in cursor.fetchall()]
        conn.close()
        return columns
    except Exception as e:
        print(f"Database error: {e}")
        return []

# Get a list of tables to use in our dropdowns
all_table_names = get_table_names()


# --- 3. CREATE "DATA MANAGEMENT" TAB WIDGETS ---
manage_table_dropdown = widgets.Dropdown(
    description='Table:',
    options=all_table_names
)
manage_join_dropdown = widgets.Dropdown(
    description='Join Table:',
    options=["(None)"] + all_table_names,
    value="(None)"
)
manage_cols_dropdown = widgets.Dropdown(
    description='Column:',
    options=[],
    value=None,
    disabled=True 
)
manage_filter_col_dropdown = widgets.Dropdown(
    options=[],
    description='Filter Column:',
    disabled=True
)
manage_filter_op_dropdown = widgets.Dropdown(
    options=['=', '!=', '>', '<', '>=', '<=', 'LIKE'],
    description='Operator:',
    value='='
)
manage_filter_val_text = widgets.Text(
    description='Value:',
    placeholder="e.g., 20 or 'M'"
)
manage_button = widgets.Button(description='Run Query')
manage_output = widgets.Output()


def on_manage_table_change(change):
    """Called when the 'Table' dropdown changes."""
    table_name = change['new']
    if table_name:
        columns = get_column_names(table_name)
        
        # Update the Column dropdown
        manage_cols_dropdown.options = ["*"] + columns
        manage_cols_dropdown.value = "*"
        manage_cols_dropdown.disabled = False
        
        # Update the Filter dropdown
        manage_filter_col_dropdown.options = ["(None)"] + columns
        manage_filter_col_dropdown.value = "(None)"
        manage_filter_col_dropdown.disabled = False
    else:
        # Disable dropdowns if no table is selected
        manage_cols_dropdown.options = []
        manage_cols_dropdown.value = None
        manage_cols_dropdown.disabled = True
        
        manage_filter_col_dropdown.options = []
        manage_filter_col_dropdown.value = None
        manage_filter_col_dropdown.disabled = True


def on_manage_query_click(b):
    """Called when the 'Run Query' button is clicked."""
    with manage_output:
        clear_output(wait=True)
        
        table = manage_table_dropdown.value
        cols = manage_cols_dropdown.value
        
        join = manage_join_dropdown.value
        if join == "(None)":
            join = None

        # Build the "WHERE" condition string
        cond = None
        filter_col = manage_filter_col_dropdown.value
        
        if filter_col != "(None)":
            filter_op = manage_filter_op_dropdown.value
            filter_val = manage_filter_val_text.value
            
            # Add quotes for text (like 'M') but not for numbers (like 20)
            if not filter_val.replace('.','',1).isdigit():
                filter_val = f"'{filter_val}'"
                
            cond = f'"{filter_col}" {filter_op} {filter_val}'
            
        print(f"Querying: SELECT {cols} FROM {table}...")
        if join:
            print(f"Joining with: {join}")
        if cond:
            print(f"Condition: WHERE {cond}")
            
        try:
            # Run the query using our SQL_Handler.py file
            results = SQL_Handler.data_selection(table, cols, cond, join)
            
            if results:
                # Get column headers for the results table
                if cols == '*':
                    df_cols = get_column_names(table)
                else:
                    df_cols = [cols] 
                    
                if join:
                    df = pd.DataFrame(results)
                    print("Note: Column headers not available for JOIN queries.")
                else:
                    df = pd.DataFrame(results, columns=df_cols)
                
                print(f"\nSuccess! Found {len(results)} rows.")
                display(df) # Display the results as a table
            else:
                print("\nQuery executed, but returned no results.")
                
        except Exception as e:
            print(f"\nAn error occurred: {e}")

# Tell the dropdown to run our function when its value changes
manage_table_dropdown.observe(on_manage_table_change, names='value')

# Tell the button to run our function when it's clicked
manage_button.on_click(on_manage_query_click)

# This VBox holds the 3-part filter
manage_filter_box = widgets.VBox([
    manage_filter_col_dropdown,
    manage_filter_op_dropdown,
    manage_filter_val_text
], layout={'border': '1px solid #CCC', 'padding': '10px'})

# This VBox holds all the widgets for the first tab
data_management_tab = widgets.VBox([
    widgets.HBox([manage_table_dropdown, manage_join_dropdown]),
    widgets.HBox([manage_cols_dropdown, manage_filter_box]), 
    manage_button,
    manage_output
])

# Load the columns for the first table on startup
if all_table_names:
    on_manage_table_change({'new': all_table_names[0]})


# --- 4. CREATE "DATA GRAPHING" TAB WIDGETS ---
graph_table_dropdown = widgets.Dropdown(options=all_table_names, description="Table:")
graph_plot_type = widgets.Dropdown(options=["Bar", "Scatter"], description="Plot Type:")
graph_x_axis = widgets.Dropdown(description="X-Axis:")
graph_y_axis = widgets.Dropdown(description="Y-Axis:")
graph_best_fit = widgets.Checkbox(value=True, description="Add Best-Fit Line")
graph_generate_button = widgets.Button(description="Generate Plot", button_style="success")
graph_output = widgets.Output()

def graph_update_columns(change):
    """Called when the graphing 'Table' or 'Plot Type' dropdowns change."""
    table = graph_table_dropdown.value
    ptype = graph_plot_type.value
    
    conn = sqlite3.connect(DB_NAME)
    try:
        df = pd.read_sql_query(f"SELECT * FROM '{table}' LIMIT 50", conn)
    except Exception as e:
        with graph_output:
            print(f"Error reading table {table}: {e}")
            graph_x_axis.options = []
            graph_y_axis.options = []
        return
    finally:
        conn.close()

    if ptype == "Scatter":
        # Scatter plots only make sense with numbers
        numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
        graph_x_axis.options = numeric_cols
        graph_y_axis.options = numeric_cols
        graph_best_fit.layout.display = 'flex' # Show checkbox
    else:
        # Bar plots can use any column
        all_cols = df.columns.tolist()
        graph_x_axis.options = all_cols
        graph_y_axis.options = all_cols
        graph_best_fit.layout.display = 'none' # Hide checkbox

    # Set default values
    if graph_x_axis.options:
        graph_x_axis.value = graph_x_axis.options[0]
    if graph_y_axis.options and len(graph_y_axis.options) > 1:
        graph_y_axis.value = graph_y_axis.options[1]
    elif graph_y_axis.options:
        graph_y_axis.value = graph_y_axis.options[0]

# Tell the dropdowns to run our function when they change
graph_table_dropdown.observe(graph_update_columns, names="value")
graph_plot_type.observe(graph_update_columns, names="value")

def graph_generate_plot(b):
    """Called when the 'Generate Plot' button is clicked."""
    with graph_output:
        clear_output(wait=True)
        
        table = graph_table_dropdown.value
        x_col = graph_x_axis.value
        y_col = graph_y_axis.value
        kind = graph_plot_type.value
        
        if not x_col or not y_col:
            print("⚠️ Please select valid columns.")
            return

        try:
            conn = sqlite3.connect(DB_NAME)
            df = pd.read_sql_query(f"SELECT * FROM '{table}'", conn)
            conn.close()
        except Exception as e:
            print(f"Error querying database: {e}")
            return

        # --- Filter by Age (18-22) ---
        if "age" in df.columns:
            df["age"] = pd.to_numeric(df["age"], errors="coerce")
            original_count = len(df)
            df = df.dropna(subset=['age'])
            df = df[(df["age"] >= 18) & (df["age"] <= 22)]
            if 'age' in df.columns:
                df['age'] = df['age'].round()
            print(f"Filter applied: Kept {len(df)} of {original_count} students aged 18–22.")

        # --- THIS IS THE NEWLY ADDED STATS BLOCK ---
        stats_box = widgets.VBox([], layout={'border': '1px solid #CCC', 'padding': '10px', 'margin_bottom': '10px'})
        stats_widgets = []

        # Calculate stats for X-Axis
        x_data_numeric = pd.to_numeric(df[x_col], errors='coerce')
        if not x_data_numeric.isnull().all(): # Check if it's a numeric column
            x_stats = x_data_numeric.describe()
            stats_widgets.append(widgets.HTML(f"<b>Stats for {x_col} (X-Axis):</b>"))
            stats_widgets.append(widgets.HTML(f"&nbsp; &nbsp; <b>Count:</b> {x_stats['count']:.0f}"))
            stats_widgets.append(widgets.HTML(f"&nbsp; &nbsp; <b>Mean:</b> {x_stats['mean']:.2f}"))
            stats_widgets.append(widgets.HTML(f"&nbsp; &nbsp; <b>Min:</b> {x_stats['min']:.2f}"))
            stats_widgets.append(widgets.HTML(f"&nbsp; &nbsp; <b>Max:</b> {x_stats['max']:.2f}"))
        else:
            # It's a categorical (text) column
            stats_widgets.append(widgets.HTML(f"<b>Stats for {x_col} (X-Axis):</b>"))
            stats_widgets.append(widgets.HTML(f"&nbsp; &nbsp; <b>Count:</b> {df[x_col].count()}"))
            stats_widgets.append(widgets.HTML(f"&nbsp; &nbsp; <b>Unique Values:</b> {df[x_col].nunique()}"))

        # Calculate stats for Y-Axis (and avoid repeating X)
        if x_col != y_col:
            y_data_numeric = pd.to_numeric(df[y_col], errors='coerce')
            if not y_data_numeric.isnull().all(): # Check if it's a numeric column
                y_stats = y_data_numeric.describe()
                stats_widgets.append(widgets.HTML(f"<b>Stats for {y_col} (Y-Axis):</b>"))
                stats_widgets.append(widgets.HTML(f"&nbsp; &nbsp; <b>Count:</b> {y_stats['count']:.0f}"))
                stats_widgets.append(widgets.HTML(f"&nbsp; &nbsp; <b>Mean:</b> {y_stats['mean']:.2f}"))
                stats_widgets.append(widgets.HTML(f"&nbsp; &nbsp; <b>Min:</b> {y_stats['min']:.2f}"))
                stats_widgets.append(widgets.HTML(f"&nbsp; &nbsp; <b>Max:</b> {y_stats['max']:.2f}"))
            else:
                # It's a categorical (text) column
                stats_widgets.append(widgets.HTML(f"<b>Stats for {y_col} (Y-Axis):</b>"))
                stats_widgets.append(widgets.HTML(f"&nbsp; &nbsp; <b>Count:</b> {df[y_col].count()}"))
                stats_widgets.append(widgets.HTML(f"&nbsp; &nbsp; <b>Unique Values:</b> {df[y_col].nunique()}"))

        # Add all the HTML widgets to the box and display it
        stats_box.children = stats_widgets
        display(stats_box)
        # --- END OF NEW STATS BLOCK ---

        plt.figure(figsize=(7,4))
        
        # --- Scatter Plot ---
        if kind == "Scatter":
            df[x_col] = pd.to_numeric(df[x_col], errors='coerce')
            df[y_col] = pd.to_numeric(df[y_col], errors='coerce')
            df = df.dropna(subset=[x_col, y_col]) 

            if df.empty:
                print("No numeric data to plot for scatter.")
                plt.close()
                return
                
            plt.scatter(df[x_col], df[y_col], color="skyblue", alpha=0.7)
            
            if graph_best_fit.value and len(df) > 1:
                m, b = np.polyfit(df[x_col], df[y_col], 1)
                plt.plot(df[x_col], m*df[x_col]+b, color="red")

        # --- Bar Plot ---
        else:
            x_data = pd.to_numeric(df[x_col], errors='coerce')
            if x_data.isnull().all(): 
                x_data = df[x_col]
                
            y_data = pd.to_numeric(df[y_col], errors='coerce')
            if y_data.isnull().all():
                y_data = df[y_col]
                
            df[x_col] = x_data
            df[y_col] = y_data
            
            df = df.dropna(subset=[x_col, y_col])
            if df.empty:
                print("No data to plot for bar chart.")
                plt.close()
                return
            
            # Helper function to decide if a column is "categorical" (text or few numbers)
            def is_col_categorical(col_name, unique_thresh=25):
                if pd.api.types.is_object_dtype(df[col_name]):
                    return True
                if pd.api.types.is_numeric_dtype(df[col_name]):
                    if df[col_name].nunique() < unique_thresh:
                        if 'id' not in str(col_name).lower():
                            return True
                return False

            is_x_categorical = is_col_categorical(x_col)
            is_y_categorical = is_col_categorical(y_col)
            
            # Case 1: Y-axis is categorical (e.g., X=age, Y=gender)
            if is_y_categorical:
                grouped = pd.crosstab(df[x_col], df[y_col]) 
                grouped = grouped.sort_index()
                
                indices = np.arange(len(grouped))
                bar_width = 0.8 / len(grouped.columns) 
                
                for i, col in enumerate(grouped.columns):
                    plt.bar(
                        indices + i*bar_width, 
                        grouped[col].values, 
                        width=bar_width, 
                        label=str(col), 
                        alpha=0.8
                    )
                
                plt.legend(title=y_col)
                plt.ylabel("Count")

                labels = grouped.index.astype(str)
                positions = indices + bar_width*(len(grouped.columns)/2 - 0.5)
                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 2: X-axis is categorical, Y-axis is numeric (e.g., X=age, Y=student_id)
            elif is_x_categorical and not is_y_categorical:
                agg_data = df.groupby(x_col)[y_col].count() 
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel("Number of Students") 

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")

            # Case 3: Both axes are numeric (and not categorical)
            else:
                agg_data = df.groupby(x_col)[y_col].count()
                agg_data = agg_data.sort_index()
                
                labels = agg_data.index.astype(str)
                positions = np.arange(len(labels))
                
                plt.bar(positions, agg_data.values, color="skyblue")
                plt.ylabel(f"Count of {y_col}")

                if len(positions) > 50:
                    plt.xticks([])
                else:
                    plt.xticks(positions, labels, rotation=45, ha="right")
            
        plt.title(f"{y_col} vs {x_col} ({kind})")
        plt.xlabel(x_col)
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show() # Display the plot

# Tell the button to run our function when it's clicked
graph_generate_button.on_click(graph_generate_plot)

# This VBox holds all the widgets for the second tab
data_graphing_tab = widgets.VBox([
    widgets.HBox([graph_table_dropdown, graph_plot_type]),
    widgets.HBox([graph_x_axis, graph_y_axis]),
    widgets.HBox([graph_generate_button, graph_best_fit]),
    graph_output
])

# Load the columns for the first table on startup
if all_table_names:
    graph_update_columns(None)


# --- 5. ASSEMBLE AND DISPLAY THE TABS ---
tab_container = widgets.Tab()
tab_container.children = [data_management_tab, data_graphing_tab]
tab_container.set_title(0, 'Data Management')
tab_container.set_title(1, 'Data Graphing')

# This is the final step: display the tabs!
display(tab_container)

Successfully imported SQL_Handler.py


Tab(children=(VBox(children=(HBox(children=(Dropdown(description='Table:', options=('Students', 'Academic', 'P…