# In [6]:  (Jupyter cell prompt left over from the notebook export)
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time  

# Stylesheet injected into the page to restyle Streamlit buttons
# (rounded blue buttons with a slight grow-on-hover effect).
_BUTTON_CSS = """
<style>
    .stButton>button {
        transition: all 0.3s ease;
        border: none;
        background-color: #008CBA;
        color: white;
        padding: 10px 20px;
        text-align: center;
        text-decoration: none;
        display: inline-block;
        font-size: 16px;
        margin: 4px 2px;
        cursor: pointer;
        border-radius: 12px;
    }
    .stButton>button:hover {
        background-color: #007B9A;
        transform: scale(1.05);
    }
</style>
"""

# Use the wide page layout, then show the app title.
st.set_page_config(layout="wide")
st.title("Simplified AI: Enhanced EDA Tool with Outlier Removal")

# unsafe_allow_html is needed so the <style> tag is emitted as raw HTML.
st.markdown(_BUTTON_CSS, unsafe_allow_html=True)

# Build a dictionary overview of a DataFrame
def summarize_data(df):
    """Collect basic descriptive information about ``df``.

    Args:
        df: The pandas DataFrame to describe.

    Returns:
        A dict with four entries, in this order: ``"Shape"`` (the
        ``df.shape`` tuple), ``"Columns"`` (list of column labels),
        ``"Missing Values"`` (column label -> null count) and
        ``"Summary Statistics"`` (the ``df.describe()`` table as a nested
        dict keyed by column label).
    """
    overview = {}
    overview["Shape"] = df.shape
    overview["Columns"] = [*df.columns]

    null_counts = df.isnull().sum()
    overview["Missing Values"] = null_counts.to_dict()

    stats_table = df.describe()
    overview["Summary Statistics"] = stats_table.to_dict()
    return overview

# Function to plot correlation heatmap
def plot_correlation_heatmap(df):
    """Render a correlation heatmap of the numerical columns of ``df``.

    Only numerical columns take part in the correlation: on recent pandas
    versions ``DataFrame.corr`` raises when the frame contains text columns,
    which is the normal case for an uploaded dataset.

    Args:
        df: The pandas DataFrame to analyse. Non-numerical columns are
            ignored.

    Returns:
        None. The chart (or an informational message when there is nothing
        to correlate) is written to the Streamlit page.
    """
    numeric_df = df.select_dtypes(include=[np.number])
    if numeric_df.shape[1] == 0:
        # seaborn cannot draw a heatmap from an empty correlation matrix.
        st.info("No numerical columns available for a correlation heatmap.")
        return

    # Draw on an explicit figure instead of pyplot's global "current figure",
    # and close it afterwards: Streamlit reruns the script on every
    # interaction, so unclosed figures would otherwise pile up in memory.
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.heatmap(numeric_df.corr(), annot=True, cmap="coolwarm", ax=ax)
    ax.set_title("Correlation Heatmap")
    st.pyplot(fig)
    plt.close(fig)

# Helper shared by outlier detection and removal (IQR rule)
def _iqr_outlier_mask(df, column, factor=1.5):
    """Return a boolean Series flagging rows whose ``column`` value lies
    outside ``[Q1 - factor * IQR, Q3 + factor * IQR]``."""
    values = df[column]
    q1 = values.quantile(0.25)
    q3 = values.quantile(0.75)
    iqr = q3 - q1
    lower_bound = q1 - factor * iqr
    upper_bound = q3 + factor * iqr
    # Comparisons with NaN are False, so missing values are never flagged.
    return (values < lower_bound) | (values > upper_bound)

# Function to detect outliers using IQR
def detect_outliers(df, column, factor=1.5):
    """Return the rows of ``df`` whose ``column`` value is an IQR outlier.

    Args:
        df: The pandas DataFrame to inspect.
        column: Label of a numerical column in ``df``.
        factor: Multiple of the interquartile range used for the fences.
            The default of 1.5 is the value this function always used.

    Returns:
        A DataFrame containing only the outlier rows, with their original
        index labels.
    """
    return df[_iqr_outlier_mask(df, column, factor).to_numpy()]

# remove outliers
def remove_outliers(df, column, factor=1.5):
    """Return a copy of ``df`` without the IQR outliers of ``column``.

    Rows are filtered by position rather than dropped by index label, so a
    non-outlier row that happens to share an index label with an outlier is
    kept.

    Args:
        df: The pandas DataFrame to clean; it is not modified.
        column: Label of a numerical column in ``df``.
        factor: Multiple of the interquartile range used for the fences.

    Returns:
        A new DataFrame with the outlier rows removed.
    """
    return df[~_iqr_outlier_mask(df, column, factor).to_numpy()]

# Function to visualize data distribution
def visualize_distribution(df, column):
    """Render a 30-bin histogram with a KDE curve for ``df[column]``.

    Args:
        df: The pandas DataFrame holding the data.
        column: Label of the column to plot.

    Returns:
        None. The chart is written to the Streamlit page.
    """
    # Use an explicit figure and close it after rendering; relying on
    # pyplot's global figure leaves one more open figure behind on every
    # Streamlit rerun.
    fig, ax = plt.subplots(figsize=(8, 5))
    sns.histplot(df[column], kde=True, bins=30, color="blue", ax=ax)
    ax.set_title(f"Distribution Plot for {column}")
    st.pyplot(fig)
    plt.close(fig)

# Function for box plot
def plot_boxplot(df, column):
    """Render a vertical box plot for ``df[column]``.

    Args:
        df: The pandas DataFrame holding the data.
        column: Label of the column to plot.

    Returns:
        None. The chart is written to the Streamlit page.
    """
    # Use an explicit figure and close it after rendering; relying on
    # pyplot's global figure leaves one more open figure behind on every
    # Streamlit rerun.
    fig, ax = plt.subplots(figsize=(8, 5))
    sns.boxplot(y=df[column], ax=ax)
    ax.set_title(f"Box Plot for {column}")
    st.pyplot(fig)
    plt.close(fig)

# Function for scatter plot
def plot_scatter(df, x_column, y_column):
    """Render a scatter plot of ``y_column`` against ``x_column``.

    Args:
        df: The pandas DataFrame holding the data.
        x_column: Label of the column shown on the x-axis.
        y_column: Label of the column shown on the y-axis.

    Returns:
        None. The chart is written to the Streamlit page.
    """
    # Use an explicit figure and close it after rendering; relying on
    # pyplot's global figure leaves one more open figure behind on every
    # Streamlit rerun.
    fig, ax = plt.subplots(figsize=(8, 5))
    sns.scatterplot(data=df, x=x_column, y=y_column, ax=ax)
    ax.set_title(f"Scatter Plot: {x_column} vs {y_column}")
    st.pyplot(fig)
    plt.close(fig)

# Upload widget: accepts a CSV or Excel (.xlsx) dataset
uploaded_file = st.file_uploader("Upload your dataset (CSV/Excel)", type=["csv", "xlsx"])

# Seed the per-session storage; keys that already exist are left untouched.
#   results -> list of saved analysis records
#   df      -> the currently loaded DataFrame (None until a file is read)
for _state_key, _initial_value in (("results", []), ("df", None)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

# Main page: runs only while a file is present in the uploader.
if uploaded_file:
    # Streamlit re-executes this script on every widget interaction, so the
    # file is parsed only when a different upload arrives. Re-reading it on
    # each run would overwrite st.session_state.df and silently undo
    # "Remove Outliers" (and previously cost a fixed 2-second sleep each time).
    # file_id is read defensively because not every Streamlit release
    # exposes it; name and size act as the fallback identity.
    file_key = (
        getattr(uploaded_file, "file_id", None),
        uploaded_file.name,
        uploaded_file.size,
    )
    needs_load = (
        st.session_state.get("df") is None
        or st.session_state.get("loaded_file_key") != file_key
    )
    if needs_load:
        with st.spinner("Loading dataset..."):
            # Case-insensitive check so "DATA.CSV" is not sent to the Excel reader.
            if uploaded_file.name.lower().endswith(".csv"):
                st.session_state.df = pd.read_csv(uploaded_file)
            else:
                st.session_state.df = pd.read_excel(uploaded_file)
        st.session_state.loaded_file_key = file_key

    st.subheader("Dataset Overview")
    col1, col2 = st.columns(2)
    with col1:
        st.write(f"Shape: {st.session_state.df.shape}")
        st.dataframe(st.session_state.df.head())
    with col2:
        st.write("**Column Types:**")
        st.write(st.session_state.df.dtypes)

    st.subheader("Summary Statistics")
    summary = summarize_data(st.session_state.df)
    for key, value in summary.items():
        st.write(f"**{key}:**")
        if isinstance(value, dict):
            if value and all(isinstance(v, dict) for v in value.values()):
                # Nested dict (per-column statistics): show it as a proper
                # table instead of a single row whose cells are dicts.
                st.write(pd.DataFrame(value))
            else:
                st.write(pd.DataFrame.from_dict({k: [v] for k, v in value.items()}, orient='columns'))
        else:
            st.write(value)

    # Column groups used by the analysis sections below
    numerical_columns = st.session_state.df.select_dtypes(include=[np.number]).columns.tolist()
    categorical_columns = st.session_state.df.select_dtypes(include=['object', 'category']).columns.tolist()

    st.subheader("Correlation Heatmap")
    if numerical_columns:
        # Correlate numerical columns only; text columns make
        # DataFrame.corr raise on recent pandas versions.
        plot_correlation_heatmap(st.session_state.df[numerical_columns])
    else:
        st.info("No numerical columns available for a correlation heatmap.")

    if numerical_columns:
        with st.expander("Numerical Data Analysis"):
            selected_column = st.selectbox("Choose a numerical column for analysis", numerical_columns)
            # Compare against None: a truthiness test would skip a column
            # whose label is falsy (for example the integer header 0).
            if selected_column is not None:
                col1, col2 = st.columns(2)
                with col1:
                    plot_type = st.radio("Select Plot Type", ["Distribution Plot", "Box Plot"])
                    if plot_type == "Distribution Plot":
                        visualize_distribution(st.session_state.df, selected_column)
                    else:
                        plot_boxplot(st.session_state.df, selected_column)
                with col2:
                    st.write("Outlier Detection")
                    outliers = detect_outliers(st.session_state.df, selected_column)
                    st.write(f"Number of outliers in `{selected_column}`: {len(outliers)}")
                    st.dataframe(outliers)

                    if st.button("Remove Outliers"):
                        # The cleaned frame is kept in session state, so it
                        # survives later reruns. Charts drawn above this point
                        # still show the pre-removal data until the next rerun.
                        st.session_state.df = remove_outliers(st.session_state.df, selected_column)
                        st.write("Outliers removed from the dataset.")
                        st.write(f"New shape of dataset: {st.session_state.df.shape}")

                    if st.button("Save Outlier Analysis"):
                        st.session_state.results.append({
                            "analysis_type": "Outlier",
                            "column": selected_column,
                            "outlier_count": len(outliers)
                        })
                        st.success(f"Outlier analysis for '{selected_column}' saved!")

    # Bivariate Analysis
    if len(numerical_columns) > 1:
        with st.expander("Bivariate Analysis"):
            x_col = st.selectbox("X-Axis", numerical_columns)
            # Offer every numerical column except the one already on the x-axis.
            y_col = st.selectbox("Y-Axis", [col for col in numerical_columns if col != x_col])
            plot_scatter(st.session_state.df, x_col, y_col)
            if st.button("Save Bivariate Analysis"):
                st.session_state.results.append({
                    "analysis_type": "Bivariate",
                    "x_axis": x_col,
                    "y_axis": y_col
                })
                st.success(f"Bivariate analysis for '{x_col}' vs '{y_col}' saved!")

    # Categorical data analysis
    if categorical_columns:
        with st.expander("Categorical Data Analysis"):
            cat_col = st.selectbox("Select a categorical column", categorical_columns)
            # Explicit figure, closed after rendering, so reruns do not
            # accumulate open matplotlib figures.
            fig, ax = plt.subplots(figsize=(10, 6))
            sns.countplot(data=st.session_state.df, x=cat_col, ax=ax)
            ax.set_title(f"Bar Plot for {cat_col}")
            st.pyplot(fig)
            plt.close(fig)
            if st.button("Save Categorical Analysis"):
                st.session_state.results.append({
                    "analysis_type": "Categorical",
                    "column": cat_col
                })
                st.success(f"Categorical analysis for '{cat_col}' saved!")

    # Display saved results
    with st.expander("View Saved Analysis"):
        if st.session_state.results:
            for result in st.session_state.results:
                st.write(f"**{result['analysis_type']} Analysis:**")
                if result['analysis_type'] == "Outlier":
                    st.write(f"- Column: {result['column']}")
                    st.write(f"- Outliers: {result['outlier_count']}")
                elif result['analysis_type'] == "Bivariate":
                    st.write(f"- X-Axis: {result['x_axis']}")
                    st.write(f"- Y-Axis: {result['y_axis']}")
                elif result['analysis_type'] == "Categorical":
                    st.write(f"- Column: {result['column']}")
        else:
            st.write("No analysis saved yet.")
else:
    # No file in the uploader: forget the previous upload so that uploading
    # the same file again reloads it from scratch.
    st.session_state.loaded_file_key = None