import streamlit as st
import pandas as pd
from statsmodels.formula.api import ols
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.api as sm

# Page layout: sets the browser-tab title and asks for the centered layout.
st.set_page_config(page_title='Machine Learning(Linear Regression) App', layout='centered')

# Page heading: the string passed to st.write starts with "# ", i.e. it is
# written as a Markdown level-1 heading.
st.write("""
# Machine Learning App""")
# Introductory text shown under the heading.
st.subheader(""" Linear Regression Model
In this implementation, the *Ordinary Least Squares()* function is used to build a linear regression model.
""")

# Model variable: module-level slot for the fitted regression results.
# build_model() rebinds it through `global model`; the page code further down
# reads `model.resid` from it after calling build_model().
model = None

# Model building
def build_model(df, param1, param2):
    """Fit an ordinary-least-squares regression of ``param1`` on ``param2``.

    The fitted results object is also stored in the module-level ``model``
    variable, because the page code reads ``model.resid`` after calling this
    function.

    Args:
        df: DataFrame holding the data for the regression.
        param1: Left-hand side of the formula (the dependent variable).
        param2: Right-hand side of the formula (the independent variable).
            Both are interpolated verbatim into the formula string, so they
            must be names/terms that are valid inside a formula.

    Returns:
        A ``(fitted_values, model_summary)`` tuple: the model's predictions
        for the rows of ``df`` and the summary object of the fit.
    """
    global model

    # Formula of the form "<dependent> ~ <independent>", e.g. "Sales ~ Radio".
    ols_formula = f"{param1} ~ {param2}"

    # Build the OLS model from the formula and fit it to the data.
    model = ols(formula=ols_formula, data=df).fit()

    # Predict from the DataFrame itself instead of the single Series
    # ``df[param2]``: the formula picks the columns it needs by name, so this
    # also works when ``param2`` is not exactly one column name, and the
    # result stays indexed like ``df``.
    # NOTE(review): with a bare Series statsmodels has to guess whether it is
    # one row or one column - confirm against the installed version if the
    # Series form is ever reintroduced.
    fitted_values = model.predict(df)

    # Summary of the fit (coefficients, R-squared, ...).
    model_summary = model.summary()

    return fitted_values, model_summary

# Function to calculate sales based on radio promotion budget, slope, and y-intercept
def calculate_sales(radio_budget, slope, y_intercept):
    """Evaluate the straight line ``slope * radio_budget + y_intercept``.

    Args:
        radio_budget: The x value to evaluate the line at.
        slope: Slope of the line.
        y_intercept: Value of the line at x = 0.

    Returns:
        The projected y value for ``radio_budget``.
    """
    return slope * radio_budget + y_intercept

# Sidebar - CSV upload and the names of the regression columns.
# The heading is rendered with a plain call: st.sidebar.header() produces a
# heading element, not a layout container, so there is nothing to enter with
# `with`. Every widget below addresses st.sidebar explicitly.
st.sidebar.header('1. Upload your CSV data')
uploaded_file = st.sidebar.file_uploader("Upload your input CSV file", type=["csv"])
st.sidebar.markdown("[Example CSV input file](https://github.com/Enimbuild/linear_regression_model_streamlit/blob/main/data/marketing_sales_data.csv)")

# Column names used to build the "<y> ~ <x>" regression formula.
st.sidebar.title("OLS Summary Parameters")
param1 = st.sidebar.text_input("Parameter y", value="column y")
param2 = st.sidebar.text_input("Parameter x", value="column x")

# Main content: page title.
st.title("OLS Summary")

# Sidebar - manual line calculator for y = slope * x + y-intercept.
# Slope and intercept are typed in by the user here; they are not read from
# the fitted model.
st.sidebar.header('y = slope * x + y-intercept:')
slope = st.sidebar.number_input("Enter the slope with 2 decimal places:", value=0.00, step=0.01)
y_intercept = st.sidebar.number_input("Enter the y-intercept with 2 decimal places:", value=0.00, step=0.01)

# x value at which the line is evaluated.
radio_budget = st.sidebar.number_input("Enter the x value for planning:")

# Evaluate the line at the entered x value.
sales = calculate_sales(radio_budget, slope, y_intercept)

# Show the projected y value in the sidebar.
st.sidebar.write(f"Projected dependent value (y): {sales}")

# Displays the dataset, fits the regression model and renders the diagnostics.
st.subheader('1. Dataset')

# Demo data set and regression columns used when no CSV has been uploaded.
DEFAULT_DATA_URL = "https://raw.githubusercontent.com/Enimbuild/datasets/main/marketing_sales_data.csv"
DEFAULT_Y_COLUMN = "Sales"
DEFAULT_X_COLUMN = "Radio"


def _show_figure(fig):
    """Render a matplotlib figure in the app, then close it.

    Figures created through pyplot stay registered until they are closed, so
    each one is closed right after it has been rendered to keep them from
    piling up.
    """
    st.pyplot(fig)
    plt.close(fig)


def _describe_dataset(df):
    """Show an overview of ``df`` and return it without incomplete rows.

    Displays the first rows, column keys, shape, dtypes and missing-value
    counts, then drops every row containing a missing value and displays the
    missing-value counts of the cleaned frame.

    Args:
        df: The raw DataFrame as loaded from CSV.

    Returns:
        A new DataFrame containing only the rows of ``df`` that have no
        missing values.
    """
    st.write(df.head(5))
    st.write("Keys:", df.keys())
    st.write("Shape:", df.shape)
    st.write("Data types:", df.dtypes)

    st.write("Check missing values:")
    st.write(df.isna().sum())

    # Number of rows that contain at least one missing value.
    st.write("Missing values along column:", df.isna().any(axis=1).sum())

    # Keep the cleaned frame: it is what gets verified on the next line and
    # what the model and plots are built from.
    complete_rows = df.dropna(axis=0)
    st.write("Drop rows with missing values:", complete_rows)

    st.write("Missing values:", complete_rows.isnull().sum())
    return complete_rows


def _run_analysis(df, y_col, x_col):
    """Fit the regression of ``y_col`` on ``x_col`` and render all plots.

    Shows an error message instead of fitting when a requested column does
    not exist in ``df`` or when ``df`` has no rows.

    Args:
        df: DataFrame to model (expected to be free of missing values).
        y_col: Name of the dependent-variable column.
        x_col: Name of the independent-variable column.
    """
    unknown = [name for name in (y_col, x_col) if name not in df.columns]
    if unknown:
        st.error(
            f"Column(s) not found in the dataset: {', '.join(unknown)}. "
            "Enter existing column names under 'OLS Summary Parameters' in the sidebar."
        )
        return
    if df.empty:
        st.error("No rows are left after dropping missing values; the model cannot be fitted.")
        return

    # Build and display the linear regression model
    fitted_values, model_summary = build_model(df, y_col, x_col)
    st.subheader('2. Model Summary')
    st.text_area('Model Summary', str(model_summary), height=700)

    st.subheader("3. Create a regression plot using Seaborn")
    fig_reg, ax_reg = plt.subplots()
    # Plain linear fit: `logistic=True` is meant for a binary y variable and
    # does not match the linear model fitted above.
    sns.regplot(x=x_col, y=y_col, data=df, ci=None, ax=ax_reg)
    _show_figure(fig_reg)

    st.subheader("4. Create pairplot using Seaborn")
    st.subheader('Pairplot')
    pair_grid = sns.pairplot(df)
    # pairplot returns a grid object; hand its underlying figure to the app.
    _show_figure(pair_grid.figure)

    # Get the residuals from the model. build_model() stores the fitted
    # results in the module-level `model` variable.
    residuals = model.resid

    st.subheader("5. Visualize the distribution of the residuals")
    fig_hist, ax_hist = plt.subplots()
    sns.histplot(residuals, ax=ax_hist)
    ax_hist.set_xlabel("Residual Value")
    ax_hist.set_title("Histogram of Residuals")
    _show_figure(fig_hist)
    st.write("Check if the distribution is a normal one.")

    st.subheader("6. Create Q-Q plot(Quantile-Quantile plot)")
    fig_qq, ax_qq = plt.subplots()
    sm.qqplot(residuals, line='s', ax=ax_qq)
    ax_qq.set_title("Q-Q plot of Residuals")
    _show_figure(fig_qq)
    st.write("Normality assumption is met when the points follow a straight diagonal line.")

    st.subheader("Check the assumption of independent observation/homoscedasticity")

    st.subheader("7. Create a scatterplot of the residuals against fitted values")
    # Give the scatterplot its own axes; without `ax` it would be drawn on
    # the current axes, i.e. on top of the Q-Q plot created just above.
    fig_scatter, ax_scatter = plt.subplots()
    sns.scatterplot(x=fitted_values, y=residuals, ax=ax_scatter)
    ax_scatter.axhline(0, color='red', linestyle='--')  # Reference line at residual = 0
    ax_scatter.set_xlabel("Fitted Values")
    ax_scatter.set_ylabel("Residuals")
    _show_figure(fig_scatter)


if uploaded_file is None:
    # No upload: use the demo data set with its fixed regression columns
    # (the column names typed into the sidebar are not used in this case).
    df = pd.read_csv(DEFAULT_DATA_URL)
    param1 = DEFAULT_Y_COLUMN
    param2 = DEFAULT_X_COLUMN
else:
    df = pd.read_csv(uploaded_file)
    st.markdown('**1.1. Glimpse of dataset**')

_run_analysis(_describe_dataset(df), param1, param2)
