# In [None]:
import streamlit as st
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# --- Helper Functions ---
def load_sample_data():
    """Read ``sample_crop_yield_data.csv`` into a DataFrame.

    The path is relative, so the file is looked up in the current working
    directory of the running process.
    """
    sample_path = "sample_crop_yield_data.csv"
    sample_df = pd.read_csv(sample_path)
    return sample_df

def clean_data(df):
    """Return a new DataFrame without any row that has a missing value.

    ``DataFrame.dropna`` builds a new object, so the frame passed in is left
    untouched.
    """
    return df.dropna(axis=0, how="any")

def preprocess_data(df):
    """Label-encode the object-dtype columns, then standardize every column.

    Args:
        df: DataFrame to transform. It is not modified; all work happens on
            a copy so the caller's data stays intact.

    Returns:
        A ``(frame, label_encoders)`` tuple. ``frame`` is the encoded and
        scaled copy (same index and column labels as the input);
        ``label_encoders`` maps each encoded column name to the
        ``LabelEncoder`` fitted on it.

    Note:
        The ``StandardScaler`` is fitted on all columns of the frame and is
        not returned, so the scaling cannot be reverted by the caller.
    """
    # Work on a copy: the encoding below assigns into columns, which would
    # otherwise silently rewrite the DataFrame owned by the caller.
    encoded = df.copy()

    label_encoders = {}
    for col in encoded.select_dtypes(include=['object']).columns:
        le = LabelEncoder()
        encoded[col] = le.fit_transform(encoded[col])
        label_encoders[col] = le

    scaler = StandardScaler()
    # Build a fresh frame from the scaled array instead of assigning floats
    # back into the existing (possibly integer) columns.
    scaled = pd.DataFrame(
        scaler.fit_transform(encoded),
        columns=encoded.columns,
        index=encoded.index,
    )
    return scaled, label_encoders

def visualize_corr(df):
    """Draw a correlation heatmap of the numeric columns on the Streamlit page.

    Args:
        df: DataFrame whose numeric (and boolean) columns are correlated.
            Other columns are ignored.

    If there is no numeric column, a warning is shown instead of a plot.
    """
    # Restrict to numeric data up front: calling ``corr()`` on a frame that
    # still holds text columns raises in recent pandas versions.
    numeric_df = df.select_dtypes(include=["number", "bool"])
    if numeric_df.shape[1] == 0:
        st.warning("No numeric columns available for a correlation matrix.")
        return

    fig, ax = plt.subplots()
    sns.heatmap(numeric_df.corr(), annot=True, cmap='coolwarm', ax=ax)
    st.pyplot(fig)
    # pyplot keeps every figure it creates until it is closed; release it so
    # repeated script reruns do not accumulate open figures.
    plt.close(fig)

def train_model(X, y, model_name, *, test_size=0.2, random_state=42):
    """Split the data, fit the selected regressor and predict on the hold-out set.

    Args:
        X: Feature matrix.
        y: Target values aligned with ``X``.
        model_name: ``'Linear Regression'`` selects ``LinearRegression``;
            any other value selects ``RandomForestRegressor``.
        test_size: Fraction of the data held out for evaluation.
        random_state: Seed used for both the train/test split and the
            random forest, so repeated calls give the same result.

    Returns:
        A ``(model, y_test, y_pred)`` tuple: the fitted estimator, the true
        targets of the hold-out set and the predictions for that set.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    if model_name == 'Linear Regression':
        model = LinearRegression()
    else:
        # Seed the forest as well; otherwise the reported metrics change on
        # every call even though the split itself is fixed.
        model = RandomForestRegressor(random_state=random_state)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    return model, y_test, y_pred

def display_metrics(y_test, y_pred):
    """Write MAE, MSE and R² for the given true/predicted values to the page.

    Each metric is computed and written in turn, in the order listed below.
    """
    report = (
        ("**Mean Absolute Error (MAE):**", mean_absolute_error),
        ("**Mean Squared Error (MSE):**", mean_squared_error),
        ("**R² Score:**", r2_score),
    )
    for label, metric_fn in report:
        st.write(label, metric_fn(y_test, y_pred))

# --- Streamlit UI ---
st.set_page_config(
    page_title="SmartCropAI - Crop Yield Prediction",
    layout="wide",
)
st.title("🌾 SmartCropAI - Crop Yield Prediction")

# Pages offered in the sidebar; the selected entry is stored in ``choice``.
menu = [
    "Home",
    "Data Upload",
    "EDA",
    "Cleaning",
    "Preprocessing",
    "Visualization",
    "Modeling",
    "Prediction",
]
choice = st.sidebar.selectbox("Select Page", menu)

# Create the shared DataFrame slot on first use so later reads never fail.
if "df" not in st.session_state:
    st.session_state["df"] = None

# --- Home Page ---
# Page dispatch: exactly one branch runs per script execution, chosen by the
# sidebar selection in ``choice``. The working DataFrame is shared between
# pages through ``st.session_state.df`` (``None`` until data is loaded).
if choice == "Home":
    st.subheader("Welcome to SmartCropAI")
    st.write("This app helps farmers, organizations, and governments predict crop yield using machine learning.")

# --- Data Upload ---
elif choice == "Data Upload":
    st.subheader("Upload Your Dataset or Use Sample")
    file = st.file_uploader("Upload CSV File", type=["csv"])
    use_sample = st.checkbox("Use Sample Dataset")

    # An uploaded file takes precedence over the sample checkbox.
    if file is not None:
        try:
            df = pd.read_csv(file)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
            # Uploaded content is untrusted; report unreadable files on the
            # page instead of failing with a traceback.
            st.error(f"Could not read the uploaded file as CSV: {exc}")
        else:
            st.session_state.df = df
            st.success("Data uploaded successfully!")
            st.dataframe(df.head())
    elif use_sample:
        try:
            df = load_sample_data()
        except FileNotFoundError:
            st.error("Sample dataset file was not found.")
        else:
            st.session_state.df = df
            st.success("Sample data loaded!")
            st.dataframe(df.head())

# --- EDA ---
elif choice == "EDA":
    st.subheader("Exploratory Data Analysis")
    if st.session_state.df is not None:
        df = st.session_state.df.copy()
        st.write("**Data Shape:**", df.shape)
        st.write("**Missing Values:**")
        st.dataframe(df.isnull().sum())
        st.write("**Data Types:**")
        # Show dtypes as plain strings; dtype objects may not serialize
        # cleanly for display on every Streamlit version.
        st.dataframe(df.dtypes.astype(str))
        st.write("**Categorical Columns Distribution:**")
        cat_cols = df.select_dtypes(include='object').columns
        for col in cat_cols:
            st.bar_chart(df[col].value_counts())
    else:
        st.warning("Upload data first.")

# --- Cleaning ---
elif choice == "Cleaning":
    st.subheader("Data Cleaning")
    if st.session_state.df is not None:
        df = clean_data(st.session_state.df.copy())
        st.session_state.df = df
        st.success("Data cleaned (missing values dropped).")
        st.dataframe(df.head())
    else:
        st.warning("Upload data first.")

# --- Preprocessing ---
elif choice == "Preprocessing":
    st.subheader("Data Preprocessing")
    if st.session_state.df is not None:
        df, encoders = preprocess_data(st.session_state.df.copy())
        st.session_state.df = df
        st.success("Data encoded and scaled.")
        st.dataframe(df.head())
    else:
        st.warning("Upload and clean data first.")

# --- Visualization ---
elif choice == "Visualization":
    st.subheader("Correlation Matrix & Visuals")
    if st.session_state.df is not None:
        df = st.session_state.df.copy()
        visualize_corr(df)
    else:
        st.warning("Please preprocess data first.")

# --- Modeling ---
elif choice == "Modeling":
    st.subheader("Model Training & Evaluation")
    if st.session_state.df is not None:
        df = st.session_state.df.copy()
        # The regressors need purely numeric, gap-free input. Detect data
        # that skipped cleaning/preprocessing and warn instead of crashing
        # inside the model fit.
        has_non_numeric = df.select_dtypes(exclude=["number", "bool"]).shape[1] > 0
        has_missing = bool(df.isnull().values.any())
        if has_non_numeric or has_missing:
            st.warning("Please preprocess data first.")
        elif df.shape[0] < 2 or df.shape[1] < 2:
            # A train/test split needs at least two rows, and at least one
            # feature column must remain after removing the target.
            st.warning("Need at least two rows and two columns to train a model.")
        else:
            target = st.selectbox("Select target column", df.columns)
            # Compare against None so falsy column labels (e.g. 0) still work.
            if target is not None:
                X = df.drop(columns=[target])
                y = df[target]
                model_type = st.selectbox("Select model", ["Linear Regression", "Random Forest"])
                model, y_test, y_pred = train_model(X, y, model_type)
                display_metrics(y_test, y_pred)
    else:
        st.warning("Please preprocess data first.")

# --- Prediction ---
elif choice == "Prediction":
    st.subheader("Predict Yield with Trained Model")
    st.write("(In progress: implement live prediction input fields based on model features.)")
