In [1]:
!pip install streamlit ngrok pandas matplotlib seaborn langchain-google-genai langchain-experimental



In [2]:
!pip install pyngrok



In [5]:
%%writefile main.py
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_experimental.agents import create_csv_agent
import io
import re
import json

# Page-level configuration; kept at module level, ahead of every other Streamlit call in this script.
st.set_page_config(page_title="Data Insight Explorer", layout="wide")

# Function to load and preprocess data
def load_and_preprocess_data(file):
    """Read a CSV and drop rows that are missing a value in any numeric column.

    Args:
        file: A path or file-like object accepted by ``pandas.read_csv``
            (in this app, the object returned by ``st.file_uploader``).

    Returns:
        pandas.DataFrame: The parsed data. Rows with NaN in a numeric column
        are removed; NaNs in non-numeric columns are kept. The original row
        index is preserved (it is not reset).
    """
    df = pd.read_csv(file)
    numeric_cols = df.select_dtypes(include=['number']).columns
    return df.dropna(subset=list(numeric_cols))

# Function to display data summary
def display_data_summary(df):
    """Render a 20-row preview of ``df`` followed by its ``describe()`` table."""
    # Each table is built lazily so it is computed right after its heading is shown.
    sections = (
        ("First 20 Rows", lambda: df.head(20)),
        ("Data Summary", df.describe),
    )
    for heading, build_table in sections:
        st.subheader(heading)
        st.write(build_table())

# Function for statistical analysis
def statistical_analysis(df, agent):
    """Render descriptive statistics, the per-column mode, a correlation
    heatmap and value counts for the object-dtype columns of ``df``.

    Args:
        df: DataFrame to analyse.
        agent: Not used by this function; accepted so existing callers that
            pass the LangChain agent keep working.
    """
    numeric_cols = df.select_dtypes(include=['number']).columns
    has_numeric = len(numeric_cols) > 0

    st.subheader("Numeric Column Statistics")
    if has_numeric:
        st.write(df[numeric_cols].agg(['mean', 'median', 'std']))
    else:
        st.info("No numeric columns found.")

    st.subheader("Mode (Most Frequent Value)")
    modes = df.mode()
    # The mode frame has no rows when df has no rows (e.g. preprocessing
    # dropped them all); iloc[0] would then raise IndexError.
    if modes.empty:
        st.info("No data available to compute the mode.")
    else:
        st.write(modes.iloc[0])

    st.subheader("Correlation Heatmap")
    if has_numeric:
        corr = df[numeric_cols].corr()
        # Use an explicit figure (not pyplot's implicit current figure) so the
        # right figure is rendered and it is always closed, even on failure.
        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            sns.heatmap(corr, annot=True, cmap='coolwarm', ax=ax)
            st.pyplot(fig)
        finally:
            plt.close(fig)
    else:
        st.info("No numeric columns available for a correlation heatmap.")

    st.subheader("Categorical Column Analysis")
    cat_cols = df.select_dtypes(include=['object']).columns
    for col in cat_cols:
        st.write(f"Unique values in {col}:", df[col].nunique())
        st.write(df[col].value_counts())

def clean_json_string(json_string):
    """Extract the first JSON object embedded in an LLM response.

    The response may wrap the JSON in prose or Markdown code fences. The
    widest ``{...}`` span is tried first; if it does not parse (for example
    because a stray ``}`` follows the object), each ``{`` is tried in turn
    as the start of an object and any trailing text is ignored.

    Args:
        json_string: Raw text returned by the agent.

    Returns:
        str | None: The object re-serialised with ``json.dumps``, or ``None``
        if the input is not a string or contains no parseable JSON object.
    """
    if not isinstance(json_string, str):
        return None

    json_match = re.search(r'\{[\s\S]*\}', json_string)
    if not json_match:
        return None

    try:
        return json.dumps(json.loads(json_match.group(0)))
    except json.JSONDecodeError:
        pass

    # Fallback: decode one object starting at each '{', ignoring what follows it.
    decoder = json.JSONDecoder()
    start = json_string.find('{')
    while start != -1:
        try:
            json_data, _end = decoder.raw_decode(json_string, start)
        except json.JSONDecodeError:
            start = json_string.find('{', start + 1)
        else:
            return json.dumps(json_data)
    return None

def generate_plots(agent, df, num_plots=5):
    """Ask the agent to propose plots and return its suggestions as a dict.

    Args:
        agent: LangChain agent exposing ``run(prompt) -> str``.
        df: The DataFrame being analysed (not used directly here).
        num_plots: Number of plot suggestions to request.

    Returns:
        dict | None: ``{"plots": [...]}`` where every entry is a dict (expected
        keys: "title", "code", "explanation"), or ``None`` if the agent call
        fails or its response is not in that shape. In the ``None`` case an
        error has already been shown in the UI.
    """
    prompt = f"""Analyze the given dataset and suggest {num_plots} most informative and relevant plots. For each plot, provide:
    1. A title for the plot
    2. Python code to generate the plot using matplotlib and seaborn
    3. A brief explanation of what the plot shows and why it's informative

    The code will be executed with the data already loaded in a pandas DataFrame named `df`, and with `plt` (matplotlib.pyplot) and `sns` (seaborn) available, so it must not read the data from a file.

    Return your response as a JSON string with the following structure:
    {{
        "plots": [
            {{
                "title": "Plot title",
                "code": "Python code to generate the plot",
                "explanation": "Brief explanation of the plot"
            }},
            ...
        ]
    }}
    IMPORTANT: Your response should only contain the JSON string, nothing else."""

    try:
        response = agent.run(prompt)
    except Exception as exc:  # LLM/API/agent-parsing failures must not crash the app
        st.error(f"The agent failed to generate plot suggestions: {exc}")
        return None

    cleaned_json = clean_json_string(response)
    if cleaned_json is None:
        st.error("Failed to extract valid JSON from the agent's response.")
        return None

    # clean_json_string only ever returns json.dumps output, so this cannot fail.
    plot_data = json.loads(cleaned_json)

    plots = plot_data.get('plots') if isinstance(plot_data, dict) else None
    if not isinstance(plots, list) or not all(isinstance(p, dict) for p in plots):
        st.error("The agent's response did not contain a list of plot definitions under 'plots'.")
        return None

    return plot_data

def display_insights(insights, df):
    """Render each LLM-suggested plot: its code, the executed figure and its explanation.

    Args:
        insights: Dict of the form ``{"plots": [{"title": ..., "code": ...,
            "explanation": ...}, ...]}`` as returned by ``generate_plots``.
            ``None``, or a dict without a non-empty "plots" list, produces an
            informational message instead of an exception.
        df: DataFrame made available to the generated code under the name ``df``.
    """
    plots = insights.get('plots') if isinstance(insights, dict) else None
    if not isinstance(plots, list) or not plots:
        st.info("No plot suggestions to display.")
        return

    for i, plot in enumerate(plots, 1):
        if not isinstance(plot, dict):
            st.warning(f"Skipping plot {i}: unexpected format in the agent's response.")
            continue

        code = plot.get('code', '')
        st.subheader(f"Plot {i}: {plot.get('title', 'Untitled')}")
        fig_col1, fig_col2 = st.columns([3, 1])
        with fig_col1:
            st.code(code, language='python')
            # SECURITY: this executes LLM-generated code with the app's full
            # privileges. Only run this app on trusted data / in a sandbox.
            # A single dict is used as both globals and locals: with separate
            # dicts, nested scopes in `code` (functions, lambdas, generator
            # expressions, and comprehensions before Python 3.12) cannot see `df`.
            namespace = dict(globals())
            namespace.update(df=df, plt=plt, sns=sns)
            try:
                exec(code, namespace)
                st.pyplot(plt.gcf())
            except Exception as e:
                st.error(f"An error occurred while generating plot {i}: {str(e)}")
            finally:
                # Close the figure even when the generated code fails midway.
                plt.close()
        with fig_col2:
            st.write("*Explanation:*", plot.get('explanation', ''))

def create_custom_plot(df):
    """Let the user build a seaborn plot from the numeric columns of ``df``.

    Renders column/plot-type selectors and a "Create Custom Plot" button.
    When the button is clicked and plotting succeeds, the figure is stored in
    ``st.session_state.custom_plot`` as ``{'type', 'x', 'y', 'fig'}``;
    displaying it is left to the caller.
    """
    numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
    if not numeric_cols:
        st.warning("The uploaded data has no numeric columns to plot.")
        return

    x_col = st.selectbox("Select X-axis column", numeric_cols)
    y_col = st.selectbox("Select Y-axis column", numeric_cols)

    plot_types = ["Scatter", "Line", "Bar", "Box", "Violin", "Histogram"]
    plot_type = st.selectbox("Select plot type", plot_types)

    if not st.button("Create Custom Plot"):
        return

    # Plotters that take both an x and a y column.
    xy_plotters = {
        "Scatter": sns.scatterplot,
        "Line": sns.lineplot,
        "Bar": sns.barplot,
        "Box": sns.boxplot,
        "Violin": sns.violinplot,
    }

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        if plot_type == "Histogram":
            # A histogram only uses x; its y-axis is a count, not y_col.
            sns.histplot(data=df, x=x_col, ax=ax)
            ax.set(title=f"Histogram of {x_col}", xlabel=x_col, ylabel="Count")
        else:
            xy_plotters[plot_type](data=df, x=x_col, y=y_col, ax=ax)
            ax.set(title=f"{plot_type} Plot: {y_col} vs {x_col}", xlabel=x_col, ylabel=y_col)
    except Exception as exc:
        st.error(f"Could not create the {plot_type} plot: {exc}")
    else:
        st.session_state.custom_plot = {
            'type': plot_type,
            'x': x_col,
            'y': y_col,
            'fig': fig
        }
    finally:
        # Release pyplot's reference; the Figure object stored in session_state
        # is still rendered later via st.pyplot.
        plt.close(fig)

def _render_about_banner(model_name):
    """Show the grey "about" box naming the Gemini model the app actually uses."""
    display_name = model_name.split("/")[-1]
    st.markdown(
        f"""
        <div style='background-color: #f0f2f6; padding: 10px; border-radius: 5px; font-size: 0.8em;'>
        This app was developed for querying CSVs and plotting graphs using LangChain (LLM) and the {display_name} model.
        Right now, it's running on the free version of the Gemini API.
        Upgrading to a more advanced model can improve the quality of responses from the language model and provide even better insights.
        </div>
        """,
        unsafe_allow_html=True
    )


def _reset_state_on_new_file(uploaded_file):
    """Drop insights and custom plots that were generated for a previously uploaded file."""
    file_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get('data_file_key') != file_key:
        st.session_state['data_file_key'] = file_key
        for stale_key in ('insights', 'custom_plot'):
            st.session_state.pop(stale_key, None)


def _render_insights_tab(agent, df):
    """Tab 1: generate plot suggestions with the LLM and display them."""
    st.header("🖼 Auto-generated Insights")
    num_plots = st.slider("Number of plots to generate", min_value=1, max_value=10, value=5)
    if st.button("Generate Insights"):
        with st.spinner("Generating insights..."):
            insights = generate_plots(agent, df, num_plots)
        if insights is None:
            # The failure was already reported; don't keep showing older results.
            st.session_state.pop('insights', None)
        else:
            st.session_state.insights = insights

    if st.session_state.get('insights'):
        display_insights(st.session_state.insights, df)


def _render_custom_plot_tab(agent, df):
    """Tab 2: user-defined plot plus an LLM explanation of it."""
    st.header("Custom Visualization")
    create_custom_plot(df)

    custom_plot = st.session_state.get('custom_plot')
    if not custom_plot:
        return

    fig_col1, fig_col2 = st.columns([3, 1])
    with fig_col1:
        st.pyplot(custom_plot['fig'])
    with fig_col2:
        # Streamlit re-executes the script on every widget interaction, so the
        # explanation is cached on the plot dict: one LLM call per plot, not per rerun.
        if 'explanation' not in custom_plot:
            prompt = f"Analyze the {custom_plot['type']} plot of {custom_plot['y']} vs {custom_plot['x']} and provide insights."
            try:
                custom_plot['explanation'] = agent.run(prompt)
            except Exception as exc:
                st.error(f"The agent failed to explain the plot: {exc}")
                return
            st.session_state.custom_plot = custom_plot
        st.write(custom_plot['explanation'])


def _render_qa_tab(agent):
    """Tab 3: free-form questions answered by the agent."""
    st.header("Data Q&A")
    question = st.text_input("Ask a question about the data:")
    if st.button("Submit Question"):
        if not question.strip():
            st.warning("Please enter a question first.")
            return
        try:
            response = agent.run(question)
        except Exception as exc:
            st.error(f"The agent failed to answer: {exc}")
            return
        st.write(response)


def main():
    """Streamlit entry point: upload a CSV, show summary statistics, and offer
    LLM-generated plots, custom plots and Q&A over the data."""
    model_name = "models/gemini-1.5-pro"

    st.title("Querying CSVs and Plot Graphs with LLMs")
    _render_about_banner(model_name)

    main_container = st.container()
    col1, col2 = st.columns([1, 4])

    with col1:
        st.header("📁 Data Upload")
        uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

    if uploaded_file is None:
        return

    _reset_state_on_new_file(uploaded_file)

    with main_container:
        try:
            df = load_and_preprocess_data(uploaded_file)
        except Exception as exc:
            st.error(f"Could not read the CSV file: {exc}")
            return

        try:
            chat_model = ChatGoogleGenerativeAI(model=model_name, temperature=0)
            csv_file = io.StringIO(df.to_csv(index=False))
            # allow_dangerous_code=True opts in to the agent executing Python
            # that the model writes; treat this app as trusted-use only.
            agent = create_csv_agent(chat_model, csv_file, verbose=True, allow_dangerous_code=True)
        except Exception as exc:
            st.error(f"Could not initialise the language model agent: {exc}")
            return

        with col2:
            data_summary_container = st.expander("Data Summary and Statistical Analysis", expanded=True)
            with data_summary_container:
                display_data_summary(df)
                statistical_analysis(df, agent)

            tab1, tab2, tab3 = st.tabs(["Auto-generate Insights by the LLM", "Custom Visualization", "Data Q&A"])

            with tab1:
                _render_insights_tab(agent, df)

            with tab2:
                _render_custom_plot_tab(agent, df)

            with tab3:
                _render_qa_tab(agent)

# Script entry point, reached when this file is executed directly, e.g. via `streamlit run main.py`.
if __name__ == "__main__":
    main()


Writing main.py


In [None]:
from pyngrok import ngrok

# Get your authtoken from https://dashboard.ngrok.com/auth
NGROK_AUTH_TOKEN = "2l4EdtwdepQhStwmwLZfIUNbnnG_6QmcKwmgNsNuXSRi7wR55"  # Replace with your actual authtoken
ngrok.set_auth_token(NGROK_AUTH_TOKEN)

# Run Streamlit app in the background
!streamlit run main.py &

# Expose the port that Streamlit is running on
public_url = ngrok.connect(8501)  # Specify the port as an integer
print(f"Streamlit App is available on: {public_url}")


Collecting usage statistics. To deactivate, set browser.gatherUsageStats to false.
[0m
[0m
[34m[1m  You can now view your Streamlit app in your browser.[0m
[0m
[34m  Local URL: [0m[1mhttp://localhost:8501[0m
[34m  Network URL: [0m[1mhttp://172.28.0.12:8501[0m
[34m  External URL: [0m[1mhttp://34.48.8.246:8501[0m
[0m
