In [29]:
%pip install pandas matplotlib seaborn ipywidgets python-dotenv groq scikit-learn
!jupyter nbextension enable --py widgetsnbextension --sys-prefix

Note: you may need to restart the kernel to use updated packages.
Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: OK


In [30]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
from groq import Groq
import os
import re
import hashlib

from io import BytesIO
from dotenv import load_dotenv
load_dotenv()

True

In [31]:

def generate_plot_suggestions(df, model="llama3-70b-8192"):
    """Ask the Groq LLM to propose five EDA plots for ``df``.

    The prompt lists the DataFrame's column names and a five-row sample so the
    model can judge which columns are categorical and which are numeric.

    Args:
        df (pd.DataFrame): Dataset to describe to the LLM.
        model (str): Groq chat model id. Defaults to the model this notebook
            was written against.

    Returns:
        str: The raw suggestion text (also stored in the notebook global
        ``plot_suggestions_text``), or ``""`` if the LLM call failed.
    """
    print("🧠 [LLM] Generating plot suggestions...")

    # str() guards against non-string column labels (e.g. integer headers),
    # which would make ', '.join raise TypeError before the try block below.
    column_list = ', '.join(str(col) for col in df.columns)
    data_sample = df.head(5).to_string(index=False)

    prompt = f"""
You are an expert data analyst.

The user has uploaded a dataset with the following columns:
{column_list}

Here are the first 5 rows of the dataset:
{data_sample}

Your task is to recommend 5 meaningful exploratory data analysis (EDA) plots that can be generated from this dataset.

For **each plot**, provide:
1. **Chart Type**: e.g., Scatter Plot, Histogram, Box Plot, Stacked Bar Chart, Bar, Heatmap, Line, Line Plot, Pie, Pie Plot etc.
2. **Columns**: Mention only existing column names from the dataset. Clearly specify the **X-axis and Y-axis** roles, especially for Scatter or Box plots (e.g., "X: Pclass, Y: Fare").
3. **Title**: A short but descriptive title for the plot.
4. **Explanation**: Explain what this plot reveals and why it’s useful.

Use the following format exactly:

Chart Type: <Chart Type>  
Columns: X: <column1>, Y: <column2>  
Title: <plot title>  
Explanation: <Detailed explanation of what the plot tells us>

Repeat for 5 plots.

Make sure:
- You only use columns from the dataset.
- Before writing chart code, check that the columns you use actually exist in the dataset. If any column doesn't exist, skip or correct it using fuzzy matching from column list: [{column_list}]
- For boxplots or bar charts, the X-axis must be **categorical** and the Y-axis **numerical**.
- For scatter plots, both axes must be **numerical**.
""" 

    try:
        client = Groq(api_key=os.getenv("GROQ_API_KEY"))
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}]
        )
        suggestion_text = response.choices[0].message.content
        display(HTML("<h4>📊 Suggested Plot Ideas from LLM:</h4>"))
        print(suggestion_text)
        globals()["plot_suggestions_text"] = suggestion_text
        return suggestion_text

    except Exception as e:
        print(f"❌ LLM Error: {e}")
        return ""


In [32]:
def match_columns_to_df(raw_columns, df_columns):
    """Map LLM-suggested column names onto the DataFrame's real column labels.

    Matching ignores case, spaces and underscores, so "passenger id" maps to
    "PassengerId". When several real columns share a normalized form, the
    first one wins. Names with no match are returned unchanged so a later
    step can try fuzzy matching or report them.

    Args:
        raw_columns (Iterable[str]): Column names as written by the LLM.
        df_columns (Iterable): Actual column labels (e.g. ``df.columns``).

    Returns:
        list: One entry per raw column: the matched real label, or the raw
        name itself when nothing matched.
    """
    def _normalize(name):
        # str() so non-string labels (ints, etc.) can be compared too.
        return str(name).strip().lower().replace(" ", "").replace("_", "")

    # Build the lookup once instead of rescanning all columns per raw name.
    lookup = {}
    for actual_col in df_columns:
        lookup.setdefault(_normalize(actual_col), actual_col)

    matched = []
    for raw_col in raw_columns:
        key = _normalize(raw_col)
        # Membership test (not truthiness) so falsy labels such as 0 still match.
        matched.append(lookup[key] if key in lookup else raw_col)
    return matched

def extract_plot_instructions(suggestion_text):
    import re
    print("🧾 Parsing LLM suggestions...")

    instructions = []

    chart_type_map = {
    "scatter plot": "scatter",
    "scatter": "scatter",
    "bar chart": "bar",
    "bar": "bar",
    "stacked bar chart": "stacked_bar",
    "stacked bar": "stacked_bar",
    "stacked histogram": "histogram",   # normalize here
    "histogram": "histogram",
    "box plot": "boxplot",
    "boxplot": "boxplot",
    "line": "line",
    "pie": "pie",
    "heatmap": "heatmap",
    "pairplot": "pairplot"
}

    pattern = re.compile(
        r'Chart Type:\s*(.*?)\s+Columns:\s*(?:(?:X:\s*(.*?),\s*Y:\s*(.*?))|(.+?))\s+Title:\s*(.*?)\s+Explanation:\s*(.*?)(?=Chart Type:|$)',
        re.IGNORECASE | re.DOTALL
    )
    matches = pattern.findall(suggestion_text)

    print(f"🔍 Found {len(matches)} suggestions.")

    for idx, match in enumerate(matches, 1):
        chart_type_raw = match[0].strip()
        x_col = match[1].strip() if match[1] else None
        y_col = match[2].strip() if match[2] else None
        fallback_cols = match[3].strip() if match[3] else ""
        title = match[4].strip()
        desc = match[5].strip()

        # Build unified raw column string
        raw_col_string = f"{x_col}, {y_col}" if x_col and y_col else fallback_cols

        # Clean column names
        raw_col_string = re.sub(r"\s*\(.*?\)", "", raw_col_string)  # Remove parentheses
        raw_col_string = re.sub(r"(?i)(color|hue):\s*\w+", "", raw_col_string)  # Remove 'color:' or 'hue:'
        raw_col_string = re.sub(r"(?i)count of ", "", raw_col_string)
        raw_col_string = re.sub(r"(?i)count\((.*?)\)", r"\1", raw_col_string)
        # Split and normalize
        raw_columns = [col.strip() for col in raw_col_string.split(",") if col.strip()]
        col_list = match_columns_to_df(raw_columns, df.columns)  
                  
        # 🔎 Identify chart type
        chart_type = chart_type_map.get(chart_type_raw.lower(), chart_type_raw.lower())
        
        instruction = {
            'type': chart_type,
            'columns': col_list,
            'title': title,
            'description': desc
        }

        instructions.append(instruction)
        print(f"✅ {idx}. [{chart_type}] Columns: {col_list} → {title}")

    if not instructions:
        print("⚠️ No valid instructions extracted from LLM output.")

    return instructions

In [33]:
import difflib

def parse_column_aliases(col_list, actual_columns):
    """
    Attempts to map aliased column names (e.g., 'Survived and Non-Survived') to actual column names in the dataset.

    Matching is fuzzy (difflib) and case-insensitive. An entry phrased as
    "A and B" or "A/B" describes the groups of one column, so its first
    matching token is returned as the hue column instead of a plotted column.

    Args:
        col_list (list): List of columns suggested by the LLM.
        actual_columns (list): List of real column names from the DataFrame.

    Returns:
        tuple: (resolved_col_list, hue_column if found, else None)
    """
    # Lower-cased label -> real label; the first real column wins on collisions.
    lookup = {}
    for actual in actual_columns:
        lookup.setdefault(str(actual).lower(), actual)
    candidates = list(lookup)

    def _closest(text):
        # Best fuzzy match for ``text`` among the real columns, or None.
        hits = difflib.get_close_matches(text.strip().lower(), candidates, n=1)
        return lookup[hits[0]] if hits else None

    resolved_cols = []
    hue_col = None

    for col in col_list:
        cleaned = str(col).lower().strip()

        # Handle phrases like "X and Y" / "X/Y" to infer a hue column. Split on
        # the whole word "and" so names such as "Standard" stay intact.
        if re.search(r"\sand\s|/", cleaned):
            for token in re.split(r"\s+and\s+|\s*/\s*", cleaned):
                if not token.strip():
                    continue
                match = _closest(token)
                if match is not None:
                    hue_col = match
                    break
            continue

        # Direct or fuzzy match
        match = _closest(cleaned)
        if match is not None:
            resolved_cols.append(match)
        else:
            # If nothing matches, log for debugging
            print(f"⚠️ Could not match column alias: {col}")

    return resolved_cols, hue_col

def resolve_computed_columns(cols, df):
    """
    Detects and computes expressions like 'SibSp + Parch' and adds them
    as new temporary columns in the DataFrame. Returns updated column list.

    Side effect: the computed column (named e.g. 'SibSp_plus_Parch', the
    row-wise sum of the parts) is added to ``df`` in place. Entries that are
    not strings, that are already real columns, or whose parts are not all
    real columns are passed through unchanged.
    """
    updated_cols = []
    for col in cols:
        # Only string names that are not already real columns can be
        # expressions; this also protects real names such as "C++".
        if not isinstance(col, str) or col in df.columns or '+' not in col:
            updated_cols.append(col)
            continue

        parts = [p.strip() for p in col.split('+')]
        if not all(p in df.columns for p in parts):
            updated_cols.append(col)  # leave untouched if not all found
            continue

        new_col_name = "_plus_".join(parts)
        if new_col_name not in df.columns:
            df[new_col_name] = df[parts].sum(axis=1)
        updated_cols.append(new_col_name)
    return updated_cols

In [34]:
# Save Plot Helper
def save_plot(title, directory="plots"):
    """Save the current matplotlib figure as ``<directory>/<sanitized title>.png``.

    Args:
        title (str): Plot title. Characters other than word characters,
            whitespace and hyphens are dropped, and spaces become underscores.
        directory (str): Output folder, created if missing. Defaults to "plots".

    Returns:
        str: Path of the written file.
    """
    safe_title = re.sub(r"[^\w\s-]", "", str(title)).strip().replace(" ", "_")
    # A title made only of punctuation would otherwise produce a hidden ".png".
    safe_title = safe_title or "plot"
    os.makedirs(directory, exist_ok=True)
    filepath = os.path.join(directory, f"{safe_title}.png")
    plt.savefig(filepath)
    print(f"💾 Plot saved to: {filepath}")
    return filepath

# Plot Dispatch Dictionary

def _box_axes(df, cols):
    """Order two columns as (categorical x, numeric y) for a box plot.

    A column counts as categorical when it is non-numeric or has fewer than
    10 distinct values. If the pair is not one of each, the columns are
    returned in the order given.
    """
    cat, num = None, None
    for col in cols:
        if df[col].nunique() < 10 or not pd.api.types.is_numeric_dtype(df[col]):
            cat = col
        else:
            num = col
    if cat is not None and num is not None:
        return cat, num
    return cols[0], cols[1]


def boxplot_handler(df, cols, title):
    """Draw a box plot of one column, or of a numeric column grouped by a categorical one.

    * 1 column: a single box for that column.
    * 2 columns: the categorical column goes on x and the numeric one on y
      (see ``_box_axes``). The axis labels always name the plotted columns.

    Any other number of columns is reported and skipped.
    """
    if len(cols) == 1:
        x_col, y_col = None, cols[0]
        sns.boxplot(data=df, y=y_col)
    elif len(cols) == 2:
        x_col, y_col = _box_axes(df, cols)
        sns.boxplot(data=df, x=x_col, y=y_col)
    else:
        print(f"⚠️ Box plot needs 1 or 2 columns, got {len(cols)}: {cols}")
        return

    plt.title(title)
    if x_col is not None:
        plt.xlabel(x_col)
    plt.ylabel(y_col)
    plt.tight_layout()

    plt.show()

def barplot_handler(df, cols, title):
    """Draw a bar chart for one or two columns.

    * 1 column: number of rows per category (``sns.countplot``).
    * 2 columns: ``cols[0]`` on the x-axis. If ``cols[1]`` is numeric, the bar
      height is its mean per category. Otherwise it is the number of non-null
      ``cols[1]`` values per category.

    Any other number of columns is reported and skipped.
    """
    if len(cols) == 1:
        sns.countplot(data=df, x=cols[0])
        plt.ylabel("Count")
    elif len(cols) == 2:
        x_col, y_col = cols[0], cols[1]
        grouped = df.groupby(x_col)[y_col]
        # Counting a numeric Y just reproduces the group sizes and ignores
        # its values; the mean is what "Y by X" bar suggestions ask for.
        if pd.api.types.is_numeric_dtype(df[y_col]):
            grouped.mean().plot(kind='bar')
            plt.ylabel(f"Mean of {y_col}")
        else:
            grouped.count().plot(kind='bar')
            plt.ylabel(f"Count of {y_col}")
    else:
        print(f"⚠️ Bar chart needs 1 or 2 columns, got {len(cols)}: {cols}")
        return

    plt.title(title)
    plt.xlabel(cols[0])
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

def stacked_barplot_handler(df, cols, title):
    """Draw a stacked bar chart of row counts.

    * 2 columns: ``cols[0]`` on the x-axis, bars stacked by ``cols[1]``.
    * 3 columns: ``cols[1]`` on the x-axis, stacked by ``cols[2]``. Heights
      count the non-null ``cols[0]`` values in each cell.

    Any other number of columns is reported and skipped.
    """
    if len(cols) == 2:
        x_col = cols[0]
        ctab = pd.crosstab(df[cols[0]], df[cols[1]])
    elif len(cols) == 3:
        x_col = cols[1]
        ctab = pd.crosstab(df[cols[1]], df[cols[2]], values=df[cols[0]], aggfunc='count').fillna(0)
    else:
        print(f"⚠️ Stacked bar chart needs 2 or 3 columns, got {len(cols)}: {cols}")
        return

    ctab.plot(kind='bar', stacked=True)
    plt.ylabel("Count")
    plt.title(title)
    # Label the axis with the column that actually forms the crosstab index.
    plt.xlabel(x_col)
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

def histogram_handler(df, cols, title):
    """Draw a histogram of ``cols[0]`` with at most 20 bins.

    An optional second column with at most 10 distinct values becomes the
    colour grouping (seaborn ``hue``). Any other second column is ignored with
    a notice. Columns without data, or more than two columns, are reported
    and skipped.
    """
    if len(cols) not in (1, 2):
        print(f"⚠️ Histogram needs 1 or 2 columns, got {len(cols)}: {cols}")
        return

    col = cols[0]
    values = df[col].dropna()
    if values.empty:
        print(f"⚠️ Column '{col}' has no non-missing values to plot.")
        return

    hue = None
    if len(cols) == 2:
        if df[cols[1]].nunique() <= 10:
            hue = cols[1]
        else:
            print(f"ℹ️ Ignoring '{cols[1]}' as hue: too many distinct values.")

    # No more bins than distinct values, and never zero bins.
    bins = max(1, min(20, values.nunique()))
    sns.histplot(data=df, x=col, bins=bins, hue=hue)
    plt.title(title)
    plt.xlabel(col)
    plt.ylabel("Count")
    plt.tight_layout()
    plt.show()

def scatter_handler(df, cols, title):
    """Draw a scatter plot of ``cols[1]`` (y) against ``cols[0]`` (x).

    An optional third column colours the points (seaborn ``hue``). Any other
    number of columns is reported and skipped.
    """
    if len(cols) not in (2, 3):
        print(f"⚠️ Scatter plot needs 2 or 3 columns, got {len(cols)}: {cols}")
        return

    x_col, y_col = cols[0], cols[1]
    hue_col = cols[2] if len(cols) == 3 else None
    sns.scatterplot(data=df, x=x_col, y=y_col, hue=hue_col)
    plt.title(title)
    plt.xlabel(x_col)
    plt.ylabel(y_col)
    plt.tight_layout()
    plt.show()

def heatmap_handler(df, cols, title):
    """Draw an annotated heatmap for the requested columns.

    * 0-1 columns: correlation matrix of all numeric columns.
    * 2+ columns that are all numeric, where there are more than two of them
      or any has more than 10 distinct values: correlation matrix of those
      columns.
    * Otherwise: row counts of ``cols[0]`` x ``cols[1]`` (``pd.crosstab``),
      annotated as integers.

    Plotting errors are caught and reported.
    """
    try:
        if len(cols) >= 2:
            subset = df[cols].dropna()
            all_numeric = all(pd.api.types.is_numeric_dtype(subset[c]) for c in cols)
            continuous = len(cols) > 2 or any(subset[c].nunique() > 10 for c in cols)
            if all_numeric and continuous:
                # A crosstab of continuous values is one cell per distinct value.
                matrix, fmt = subset.corr(), ".2f"
            else:
                matrix, fmt = pd.crosstab(subset[cols[0]], subset[cols[1]]), "d"
        else:
            # default to correlation heatmap
            numeric_df = df.select_dtypes(include=['number'])
            matrix, fmt = numeric_df.corr(), ".2f"

        sns.heatmap(matrix, annot=True, cmap="coolwarm", fmt=fmt)
        plt.title(title)
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"❌ Error plotting heatmap: {e}")

def piechart_handler(df, cols, title):
    """Draw a pie chart of the category shares of a single column.

    Prints a warning and draws nothing unless exactly one column is given.
    Plotting errors are caught and reported instead of propagating.
    """
    try:
        if len(cols) == 1:
            counts = df[cols[0]].value_counts()
            plt.figure(figsize=(6, 6))
            plt.pie(counts.values, labels=counts.index, autopct='%1.1f%%', startangle=140)
            plt.title(title)
            plt.tight_layout()
            plt.show()
        else:
            print("⚠️ Pie chart requires exactly 1 categorical column.")
    except Exception as e:
        print(f"❌ Error plotting pie chart: {e}")

def lineplot_handler(df, cols, title):
    """Draw a line plot of one or two columns.

    * 2 columns: ``cols[1]`` against ``cols[0]``, with rows sorted by ``cols[0]``.
    * 1 column: the column's non-missing values against their position.

    Other column counts are reported and skipped. Plotting errors are caught
    and reported.
    """
    try:
        if len(cols) == 2:
            x_col, y_col = cols[0], cols[1]
            df_plot = df[[x_col, y_col]].dropna().sort_values(by=x_col)
            sns.lineplot(data=df_plot, x=x_col, y=y_col)
        elif len(cols) == 1:
            # Plot against position directly rather than adding a helper
            # column, which would overwrite a data column named 'index'.
            values = df[cols[0]].dropna().reset_index(drop=True)
            sns.lineplot(x=values.index, y=values.values)
            plt.xlabel("index")
            plt.ylabel(str(cols[0]))
        else:
            print("⚠️ Line plot supports 1 or 2 columns only.")
            return

        plt.title(title)
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"❌ Error plotting line plot: {e}")

# Mapping chart types to handler functions.
# Keys are the normalized types produced by extract_plot_instructions plus a
# few aliases. NOTE(review): the parser can emit "pairplot", which has no
# handler here, so such suggestions are reported as unsupported.
plot_dispatch = dict(
    boxplot=boxplot_handler,
    bar=barplot_handler,
    stacked_bar=stacked_barplot_handler,
    histogram=histogram_handler,
    scatter=scatter_handler,
    heatmap=heatmap_handler,
    line=lineplot_handler,
    lineplot=lineplot_handler,
    pie=piechart_handler,
    piechart=piechart_handler,
)



In [35]:
def generate_plots_from_instructions(df, instructions):
    """Render each parsed LLM plot instruction with its registered handler.

    For every instruction:
    1. Column names missing from ``df`` are fuzzy-matched (difflib) against
       the real labels.
    2. Expressions such as ``"A + B"`` are materialized by
       ``resolve_computed_columns``, which adds the column to ``df``.
    3. The chart is drawn by the handler in ``plot_dispatch``.

    Instructions whose columns still do not resolve are skipped. Titles and
    explanations come from the LLM, so they are HTML-escaped before display.

    Args:
        df (pd.DataFrame): Dataset to plot; may gain computed columns.
        instructions (list[dict]): Output of ``extract_plot_instructions``.
    """
    from html import escape  # LLM-written text must not be injected as raw HTML

    print("📈 Generating plots from LLM instructions...")
    display(HTML("<h4>📈 Generated Plots with Explanations:</h4>"))

    if not instructions:
        print("⚠️ No plot instructions provided.")
        return

    any_plotted = False

    for idx, item in enumerate(instructions, 1):
        chart_type = item['type']
        cols = item['columns']
        title = item['title']
        explanation = item.get("description", "")

        print(f"🔧 [{idx}] Type: {chart_type}, Columns: {cols}, Title: {title}")

        # Fuzzy matching fallback. difflib needs strings, so it compares
        # against str() forms of the labels. The lookup is rebuilt per item
        # because resolve_computed_columns may have added columns.
        label_lookup = {str(c): c for c in df.columns}
        fallback_cols = []
        for col in cols:
            if col in df.columns:
                fallback_cols.append(col)
                continue
            close = difflib.get_close_matches(str(col), list(label_lookup), n=1)
            fallback_cols.append(label_lookup[close[0]] if close else col)

        # Attempt resolving computed columns like 'A + B'
        resolved_cols = resolve_computed_columns(fallback_cols, df)

        # Final column check
        if not all(col in df.columns for col in resolved_cols):
            print(f"❌ Still unmatched columns: {resolved_cols}")
            continue
        cols = resolved_cols

        try:
            # ✅ Display title and explanation only (handler shows chart)
            display(HTML(f"<h5>✅ {escape(str(title))}</h5>"))

            # Dispatch to the appropriate plotting function
            plot_func = plot_dispatch.get(chart_type)
            if plot_func:
                plot_func(df, cols, title)
                if explanation:
                    display(HTML(f"<p><b>Explanation:</b> {escape(str(explanation))}</p>"))
                any_plotted = True
            else:
                print(f"⚠️ Unsupported chart type: {chart_type}")
        except Exception as e:
            print(f"❌ Error plotting {title}: {e}")
            # Discard the half-drawn figure so the next chart starts clean.
            plt.close("all")

    if not any_plotted:
        display(HTML("<b>⚠️ No valid plots were created from LLM suggestions.</b>"))

In [36]:
# Global seaborn look for every chart in this notebook. set_theme() is the
# documented name; sns.set() is only a legacy alias for it.
sns.set_theme(style="whitegrid")

# Shared output area that the upload handler prints and plots into.
output_area = widgets.Output()
display(output_area)


Output()

In [37]:

upload = widgets.FileUpload(accept='.csv', multiple=False)
submit_button = widgets.Button(description="📤 Submit File")


def _uploaded_csv_bytes(value):
    """Return the raw bytes of the first file in a ``FileUpload.value``, or None if empty.

    ipywidgets 8 stores a tuple of dicts with a ``content`` entry; ipywidgets 7
    stored a dict mapping file name to such a dict.
    """
    if not value:
        return None
    entry = next(iter(value.values())) if isinstance(value, dict) else value[0]
    return bytes(entry["content"])


def handle_upload(change):
    """Button callback: read the uploaded CSV, clean it, and run the LLM plot pipeline.

    Sets these notebook globals:
    * ``df_raw``: the frame as read.
    * ``df_cleaned``: the frame after cleaning.
    * ``df``: a copy of ``df_cleaned``.

    Cleaning converts object columns that fully parse as datetimes, then
    drops rows with any missing value and duplicate rows.
    """
    import traceback

    global df
    with output_area:
        clear_output()
        print("📥 Upload event triggered.")
        content = _uploaded_csv_bytes(upload.value)
        if content is None:
            print("⚠️ No file selected. Choose a CSV with the Upload button first.")
            return

        try:
            df = pd.read_csv(BytesIO(content))
            globals()['df_raw'] = df.copy()  # ✅ Store original uncleaned version
            print("✅ File read into DataFrame.")

            # Auto-parse datetime columns (columns that fail to parse are left as-is)
            for col in df.columns:
                if df[col].dtype == 'object':
                    try:
                        df[col] = pd.to_datetime(df[col])
                        print(f"📅 Converted '{col}' to datetime.")
                    except Exception:
                        continue

            df_cleaned = df.dropna().drop_duplicates()
            globals()["df_cleaned"] = df_cleaned

            df = df_cleaned.copy()

            if df_cleaned.empty:
                # dropna() drops a row if ANY column is missing, so a single
                # sparse column can remove every row.
                print("⚠️ No rows left after dropping missing values and duplicates; skipping plot suggestions.")
                return

            display(HTML("<h4>✅ Dataset Uploaded and Cleaned</h4>"))
            display(df_cleaned.head())

            suggestions = generate_plot_suggestions(df_cleaned)
            instructions = extract_plot_instructions(suggestions)
            generate_plots_from_instructions(df_cleaned, instructions)

        except Exception:
            print("❌ Error in handle_upload:")
            traceback.print_exc()


submit_button.on_click(handle_upload)
display(HTML("<h3>📁 Upload Your CSV Dataset:</h3>"))
display(upload, submit_button)


FileUpload(value=(), accept='.csv', description='Upload')

Button(description='📤 Submit File', style=ButtonStyle())

In [None]:
import ipywidgets as widgets
from IPython.display import display, Markdown
import traceback
import re
import json
from groq import Groq
import os

# Groq client for the query handler; the key comes from the environment
# (populated from .env by load_dotenv in the imports cell).
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# --- Query UI: question box, dataset toggle, submit button, result area ---
query_input = widgets.Textarea(
    placeholder='Ask a question about the dataset...',
    description='🧠 Query:',
    layout=widgets.Layout(width='100%', height='80px'),
)
use_cleaned_toggle = widgets.Checkbox(value=True, description='Use Cleaned Dataset')
submit_query = widgets.Button(description="Answer Query", button_style="primary")
output_query = widgets.Output()

display(Markdown("### 💬 Ask a Question About the Dataset"))
display(query_input, use_cleaned_toggle, submit_query, output_query)

# Helper and handler functions

def _parse_llm_json(reply):
    """Parse the JSON value in an LLM reply.

    Tolerates a surrounding Markdown code fence and surrounding prose, and
    uses ``strict=False`` so literal newlines inside JSON strings (common in
    multi-line code) are accepted.

    Raises:
        json.JSONDecodeError: if no JSON value can be decoded.
    """
    text = reply.strip()
    fenced = re.match(r"^```(?:json)?\s*(.*?)\s*```$", text, re.DOTALL | re.IGNORECASE)
    if fenced:
        text = fenced.group(1)
    try:
        return json.loads(text, strict=False)
    except json.JSONDecodeError:
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(text[start:end + 1], strict=False)


def _fix_column_case(code, columns):
    """Rewrite ``['col']`` / ``["col"]`` subscripts to the exact case of the real labels.

    Names are regex-escaped, so labels containing characters such as ``(`` or
    ``$`` are matched literally instead of being parsed as regex syntax.
    """
    for col in columns:
        name = str(col)
        pattern = r"\[(['\"])" + re.escape(name) + r"\1\]"
        code = re.sub(
            pattern,
            lambda m, name=name: f"[{m.group(1)}{name}{m.group(1)}]",
            code,
            flags=re.IGNORECASE,
        )
    return code


def handle_query(_):
    """Button callback: answer the typed question with LLM-generated pandas code.

    The LLM classifies the question as "data" (a single expression that is
    evaluated and shown) or "chart" (matplotlib/seaborn code that is
    executed). Both paths use the dataset chosen by ``use_cleaned_toggle``:
    the global ``df`` (cleaned) or ``df_raw`` (as uploaded).

    SECURITY NOTE: the returned code runs through eval/exec with access to the
    notebook namespace. Only use this with a trusted model and data.
    """
    output_query.clear_output()
    user_question = query_input.value.strip()

    if not user_question:
        with output_query:
            print("❗ Please enter a question.")
        return

    prompt = f"""
You are a Python data analyst working with a pandas DataFrame named `df`.

The user asked:
\"\"\"{user_question}\"\"\"

Your job:
1. Classify the query as either:
   - "data": a statistical question (e.g. average, count, percentage)
   - "chart": a visual question (e.g. plot, graph, chart)

2. If the type is "data", return a **single-line** Python expression (like df['col'].mean()) that evaluates the answer using the DataFrame `df`.

3. If the type is "chart", return Python code using matplotlib or seaborn to plot the requested visualization using the DataFrame `df`.

NOTE: When generating charts:
- Always **group and aggregate** data properly (e.g. use `.groupby()` and `.mean()` or `.sum()`).
- Avoid multi-index plots (like `('Male', 0)`); instead, pivot or aggregate to keep axes simple.
- Convert `.value_counts()` or `.groupby()` results to `.plot(kind='bar')` where appropriate.
- Use proper axis labels and titles.
- When the user asks about “top N ages,” interpret that as individual age values, not binned ranges (unless “age groups” is explicitly mentioned).

4. Your full response must be valid JSON using this format:
{{
  "type": "data" or "chart",
  "result": "python code or expression here"
}}

No explanation. Only return a valid JSON object.
"""

    try:
        response = groq_client.chat.completions.create(
            model="llama3-70b-8192",
            messages=[
                {
                    "role": "system",
                    "content": """
        You are a data analyst writing Python code to answer questions about a pandas DataFrame named `df`.

        Always follow this logic:

        1. If the user asks for a statistical answer:
        - Respond with a Python expression like: df['Survived'].mean()

        2. If the user asks for a chart:
        - Use matplotlib or seaborn
        - Group data meaningfully. Do NOT use `.value_counts()` for numeric fields
        - If query mentions 'age group', use pd.cut(df['Age'], bins=[...]) to bin ages
        - Label all axes and add titles
        - Use: df.groupby(), df.pivot_table(), or pd.cut()

        Only output this JSON format:
        {
        "type": "data" or "chart",
        "result": "python code or expression here as a string"
        }
        """
                },
                # Send the full task prompt, which embeds the question along
                # with the classification rules and the required JSON format.
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            temperature=0.1
        )
        reply = response.choices[0].message.content.strip()

        try:
            result_json = _parse_llm_json(reply)
            if not isinstance(result_json, dict):
                raise ValueError("LLM reply is not a JSON object")
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            with output_query:
                print("❌ Could not parse LLM output as JSON:")
                print(reply)
            return

        result_type = result_json.get("type")
        result_value = result_json.get("result")

        with output_query:
            if "df" not in globals():
                print("❌ Global DataFrame `df` not found. Please upload the dataset first.")
                return

            use_cleaned = use_cleaned_toggle.value
            if not use_cleaned and "df_raw" not in globals():
                print("❌ Original DataFrame `df_raw` not found. Please upload the dataset first.")
                return
            active_df = df if use_cleaned else df_raw
            data_used = "cleaned" if use_cleaned else "original"

            if result_type in ("data", "chart") and not isinstance(result_value, str):
                print("❌ LLM returned no code to run:", result_value)
                return

            if result_type == "data":
                print("🧮 Evaluating:", result_value)
                try:
                    expression = _fix_column_case(result_value, active_df.columns)
                    # SECURITY: evaluates LLM-generated code with the notebook globals.
                    result = eval(expression, globals(), {"df": active_df})

                    display(Markdown(f"📊 **Using `{data_used}` dataset**"))
                    if isinstance(result, float):
                        # Single floating-point value
                        display(Markdown(f"📝 **Answer:**\n\n{user_question} → **{result:.2f}**"))
                    elif isinstance(result, pd.Series):
                        # Series → bullet list
                        markdown_table = "\n".join(
                            f"- **{idx}**: {val}" for idx, val in result.items()
                        )
                        display(Markdown(f"📝 **Answer:**\n\n**{user_question}**\n\n{markdown_table}"))
                    elif isinstance(result, pd.DataFrame):
                        display(Markdown(f"📝 **Answer:**\n\n**{user_question}**"))
                        display(result.head())
                    else:
                        # Ints, bools and other types are shown as-is
                        # (not as "891.00" / "1.00").
                        display(Markdown(f"📝 **Answer:**\n\n{user_question} → **{result}**"))

                except Exception:
                    print("❌ Error evaluating code:")
                    traceback.print_exc()

            elif result_type == "chart":
                print("📊 Executing Chart Code:\n", result_value)
                try:
                    chart_code = _fix_column_case(result_value, active_df.columns)
                    # SECURITY: executes LLM-generated code.
                    exec(chart_code, {"df": active_df, "pd": pd, "plt": plt, "sns": sns})
                    # Figures drawn inside a widget callback are not reliably
                    # rendered by the inline backend unless shown explicitly.
                    plt.show()
                except Exception:
                    print("❌ Error running chart code:")
                    traceback.print_exc()
                    plt.close("all")
            else:
                print("❌ Unknown type returned:", result_type)

    except Exception:
        with output_query:
            print("❌ Error during LLM request:")
            traceback.print_exc()

# Run handle_query each time the "Answer Query" button is clicked.
submit_query.on_click(handle_query)

### 💬 Ask a Question About the Dataset

Textarea(value='', description='🧠 Query:', layout=Layout(height='80px', width='100%'), placeholder='Ask a ques…

Checkbox(value=True, description='Use Cleaned Dataset')

Button(button_style='primary', description='Answer Query', style=ButtonStyle())

Output()

In [None]:
import ipywidgets as widgets
from IPython.display import display, Markdown
import traceback
import json
import ast
import joblib
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from groq import Groq

# Groq client for the training handler (same API-key source as the query cell).
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# === UI Elements ===
model_query_input = widgets.Textarea(
    placeholder='Describe the model you want to train (e.g., "Train a classification model to predict Survived from Age, Sex, and Pclass")',
    description='📊 Model Query:',
    layout=widgets.Layout(width='100%', height='100px'),
)
model_type_dropdown = widgets.Dropdown(
    options=['classification', 'regression'],
    value='classification',
    description='Model Type:',
)
submit_model_btn = widgets.Button(description="Train Model", button_style="success")
# Starts disabled; the training handler enables it once a model exists.
export_model_btn = widgets.Button(description="💾 Export Model", disabled=True)
model_output = widgets.Output()

# === Display UI ===
model_controls = (model_query_input, model_type_dropdown, submit_model_btn, export_model_btn, model_output)
display(Markdown("### 🧠 Train a Machine Learning Model"))
display(*model_controls)

# === Handlers ===
def _model_column_summary(frame):
    """Return one "- col (type: dtype, missing: N%)" line per column of ``frame``."""
    dtypes = frame.dtypes.astype(str).to_dict()
    missing = frame.isnull().mean().round(2).to_dict()
    return "\n".join(
        f"- {col} (type: {dtypes[col]}, missing: {missing[col]*100:.0f}%)"
        for col in frame.columns
    )


def _parse_model_reply(reply):
    """Return the dict in an LLM reply: JSON first, else a Python dict literal.

    The training prompt asks for the code inside triple quotes, which is not
    valid JSON, so a Python-literal parse (``ast.literal_eval``) is the
    fallback. A surrounding Markdown code fence is removed first.

    Raises:
        ValueError / SyntaxError: if the reply cannot be parsed into a dict.
    """
    import re

    text = reply.strip()
    fenced = re.match(r"^```[A-Za-z]*\s*(.*?)\s*```$", text, re.DOTALL)
    if fenced:
        text = fenced.group(1)
    try:
        parsed = json.loads(text, strict=False)
    except json.JSONDecodeError:
        parsed = ast.literal_eval(text)
    if not isinstance(parsed, dict):
        raise ValueError(f"Expected a JSON object, got {type(parsed).__name__}")
    return parsed


def handle_model_training(_):
    """Button callback: have the LLM write scikit-learn training code, then run it.

    The dataset is the global ``df`` (cleaned) or ``df_raw`` (original),
    selected by ``use_cleaned_toggle`` from the query cell (cleaned if that
    toggle does not exist). The column summary sent to the LLM describes the
    same frame. If the generated code defines ``model``, the model is attached
    to ``export_model_btn`` and the button is enabled.

    SECURITY NOTE: the LLM's code is executed with ``exec``. Only use this
    with a trusted model and data.
    """
    model_output.clear_output()
    query = model_query_input.value.strip()
    if not query:
        with model_output:
            print("❗ Please enter a training request.")
        return

    # Resolve the dataset up front. A missing global here used to raise
    # outside the try below, i.e. uncaught inside the widget callback.
    if "df" not in globals():
        with model_output:
            print("❌ Global DataFrame `df` not found. Please upload the dataset first.")
        return
    toggle = globals().get("use_cleaned_toggle")
    use_cleaned = toggle.value if toggle is not None else True
    if use_cleaned:
        df_active = df
    elif "df_raw" in globals():
        df_active = df_raw
    else:
        with model_output:
            print("❌ Original DataFrame `df_raw` not found. Please upload the dataset first.")
        return

    # A new run invalidates any previously exported model.
    export_model_btn.disabled = True
    export_model_btn.model = None

    model_type = model_type_dropdown.value

    # Describe the frame that will actually be trained on.
    column_info_str = _model_column_summary(df_active)

    prompt = f"""
You are a Python data scientist. Given the user's request:
\"\"\"{query}\"\"\"

The available columns in the dataset are:
{column_info_str}

Generate Python code to train a {model_type} model using scikit-learn. The DataFrame is `df`.

You must:
- Dynamically select target and features based on the query
- Handle categorical columns using LabelEncoder or get_dummies
- Drop or ignore datetime columns if needed
- Only use columns that exist in the dataset.
- Do NOT invent or assume any column names like "Signup_Date", "User_ID", etc.
- Before selecting features, verify they exist in the list above.
- Avoid columns with >50% missing values unless explicitly requested.
- Drop rows with missing values, especially in the target column
- Before training:
    - Check if the final dataset has at least 2 rows
    - If not, raise: `raise ValueError("Insufficient data after cleaning. Please choose different features or handle missing values.")`
- Train the model (e.g., RandomForestClassifier or LinearRegression)
- Evaluate (use accuracy or R^2 score)
- Print feature importances if supported
- Assign the trained model to a variable named `model`

Return only valid JSON using this format:
{{
  "model_type": "{model_type}",
  "code": \"\"\"<Python training code here>\"\"\"
}}
"""

    try:
        response = groq_client.chat.completions.create(
            model="llama3-70b-8192",
            messages=[
                {"role": "system", "content": "You return valid JSON with model training code."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.1
        )

        reply = response.choices[0].message.content.strip()

        try:
            parsed = _parse_model_reply(reply)
        except Exception:
            with model_output:
                print("❌ Could not parse model code:")
                print(reply)
            return

        model_code = parsed.get("code")
        if not isinstance(model_code, str) or not model_code.strip():
            with model_output:
                print("❌ LLM reply did not contain any training code:")
                print(reply)
            return

        with model_output:
            print("🚀 Training model with the following code:\n")
            print(model_code)
            print("\n===============================\n")

            # SECURITY: executes LLM-generated code.
            exec_namespace = {"df": df_active, "pd": pd, "plt": plt, "sns": sns}
            exec(model_code, exec_namespace)
            # Show any figures the training code drew (e.g. importances).
            plt.show()

            trained = exec_namespace.get("model")
            if trained is not None:
                export_model_btn.model = trained
                export_model_btn.disabled = False

    except Exception:
        with model_output:
            print("❌ Error during model generation or training:")
            traceback.print_exc()

def handle_export_model(_, path="trained_model.pkl"):
    """Button callback: save the model attached to ``export_model_btn`` with joblib.

    Args:
        _: The clicked button (unused; supplied by ``on_click``).
        path (str): Destination file. Defaults to "trained_model.pkl".

    Note: joblib files are pickles. Loading one from an untrusted source can
    execute arbitrary code.
    """
    trained = getattr(export_model_btn, "model", None)
    if trained is None:
        with model_output:
            print("⚠️ No trained model to export yet. Train a model first.")
        return
    try:
        joblib.dump(trained, path)
        with model_output:
            print(f"✅ Model exported as '{path}'")
    except Exception:
        with model_output:
            print("❌ Failed to export model:")
            traceback.print_exc()

# Wire each button to its callback: train first, then export the result.
for _button, _callback in (
    (submit_model_btn, handle_model_training),
    (export_model_btn, handle_export_model),
):
    _button.on_click(_callback)

### 🧠 Train a Machine Learning Model

Textarea(value='', description='📊 Model Query:', layout=Layout(height='100px', width='100%'), placeholder='Des…

Dropdown(description='Model Type:', options=('classification', 'regression'), value='classification')

Button(button_style='success', description='Train Model', style=ButtonStyle())

Button(description='💾 Export Model', disabled=True, style=ButtonStyle())

Output()

In [42]:
%pip install streamlit

Collecting streamlit
  Downloading streamlit-1.46.1-py3-none-any.whl.metadata (9.0 kB)
Collecting altair<6,>=4.0 (from streamlit)
  Downloading altair-5.5.0-py3-none-any.whl.metadata (11 kB)
Collecting blinker<2,>=1.5.0 (from streamlit)
  Downloading blinker-1.9.0-py3-none-any.whl.metadata (1.6 kB)
Collecting click<9,>=7.0 (from streamlit)
  Using cached click-8.2.1-py3-none-any.whl.metadata (2.5 kB)
Collecting pyarrow>=7.0 (from streamlit)
  Downloading pyarrow-20.0.0-cp313-cp313-macosx_12_0_arm64.whl.metadata (3.3 kB)
Collecting toml<2,>=0.10.1 (from streamlit)
  Downloading toml-0.10.2-py2.py3-none-any.whl.metadata (7.1 kB)
Collecting gitpython!=3.1.19,<4,>=3.0.7 (from streamlit)
  Downloading GitPython-3.1.44-py3-none-any.whl.metadata (13 kB)
Collecting pydeck<1,>=0.8.0b4 (from streamlit)
  Downloading pydeck-0.9.1-py2.py3-none-any.whl.metadata (4.1 kB)
Collecting jinja2 (from altair<6,>=4.0->streamlit)
  Using cached jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
Collecting jsons

In [None]:
# angad.sharma@zomato.com