# In [None]:
import streamlit as st
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import plotly.express as px
import plotly.graph_objects as go
from sklearn.metrics import mean_squared_error, mean_absolute_percentage_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.cluster import KMeans
import os
import pickle

# Set page config - MUST be first Streamlit command
st.set_page_config(page_title="SKANEM FORECASTING", layout="wide")

# --------------------------
# App Configuration
# --------------------------

# Logo and title.
# The header is rendered exactly once (it used to be drawn twice). The company
# logo lives at a machine-specific path, so fall back to the placeholder image
# when that file is not present instead of failing on every other machine.
LOGO_PATH = "C:/Users/chris.mutuku/OneDrive - Skanem AS/Desktop/logo.jpg"
LOGO_PLACEHOLDER_URL = "https://via.placeholder.com/44"

col1, col2 = st.columns([0.1, 0.9])
with col1:
    st.image(LOGO_PATH if os.path.isfile(LOGO_PATH) else LOGO_PLACEHOLDER_URL, width=44)
with col2:
    st.title("Advanced Supply Chain Forecasting")
# --------------------------
# Data Management
# --------------------------
# Directory (relative to the working directory) that holds one pickle file per
# saved material configuration.
DATA_DIR = "forecast_data"
# Created as soon as the script runs; exist_ok keeps reruns from failing when
# the directory is already there.
os.makedirs(DATA_DIR, exist_ok=True)

def save_material_data(material_name, data):
    """Pickle ``data`` into the material's file inside ``DATA_DIR``.

    The file name is ``material_name`` with spaces turned into underscores and
    a ``.pkl`` suffix. An existing file for the same material is overwritten.
    """
    file_name = material_name.replace(' ', '_') + ".pkl"
    target = os.path.join(DATA_DIR, file_name)
    with open(target, 'wb') as handle:
        pickle.dump(data, handle)

def load_material_data(material_name):
    """Return the object pickled for ``material_name``, or ``None`` if absent.

    The file is looked up in ``DATA_DIR`` under the material name with spaces
    replaced by underscores and a ``.pkl`` suffix.
    """
    # NOTE(review): unpickling runs arbitrary code from the file; this assumes
    # the directory only ever contains files written by this app.
    file_name = material_name.replace(' ', '_') + ".pkl"
    source = os.path.join(DATA_DIR, file_name)
    if not os.path.exists(source):
        return None
    with open(source, 'rb') as handle:
        return pickle.load(handle)

def get_saved_materials():
    """Return the display names of all saved materials, sorted alphabetically.

    Every ``*.pkl`` file in ``DATA_DIR`` is one saved material. The display
    name is the file name without its ``.pkl`` suffix and with underscores
    turned back into spaces.

    Only the trailing suffix is removed: a name that happens to contain
    ``.pkl`` in the middle must keep it, otherwise the listed name would no
    longer map back to the file on disk. The result is sorted because
    ``os.listdir`` returns entries in arbitrary order.
    """
    suffix = '.pkl'
    return sorted(
        entry[:-len(suffix)].replace('_', ' ')
        for entry in os.listdir(DATA_DIR)
        if entry.endswith(suffix)
    )

# --------------------------
# Forecasting Models
# --------------------------
def calculate_metrics(actual, predicted):
    """Score ``predicted`` against ``actual`` and return the results as a dict.

    Keys:
        'RMSE': square root of the mean squared error.
        'MAPE': mean absolute percentage error, scaled to percent.
        'R2':   coefficient of determination.
    """
    rmse = np.sqrt(mean_squared_error(actual, predicted))
    mape_percent = mean_absolute_percentage_error(actual, predicted) * 100
    r_squared = r2_score(actual, predicted)
    return {'RMSE': rmse, 'MAPE': mape_percent, 'R2': r_squared}

def generate_forecast(current_balance, avg_consumption, variability, horizon):
    """Project the stock balance over the next ``horizon`` days.

    Args:
        current_balance: Stock on hand today.
        avg_consumption: Average consumption per day.
        variability: Width, in percent, of the uniform band applied to the
            daily consumption of the probabilistic series (10 means each day
            consumes between 95% and 105% of the average).
        horizon: Number of days to forecast.

    Returns:
        A ``(dates, deterministic, probabilistic)`` tuple: a daily
        ``DatetimeIndex`` starting now, and two lists of balances of length
        ``horizon``, both clipped at zero.

    Both series describe the balance at the START of each day, so entry 0 of
    either one is ``current_balance``. Previously the probabilistic series had
    already consumed one extra day at every index, which put a constant offset
    of roughly one day's consumption between the two curves.
    """
    # A private generator seeded with 42 yields the same draws as
    # np.random.seed(42) + np.random.rand, without reseeding the process-wide
    # NumPy RNG as a side effect.
    rng = np.random.RandomState(42)
    dates = pd.date_range(datetime.now(), periods=horizon)

    # Deterministic forecast: constant daily draw-down.
    days_elapsed = np.arange(horizon)
    deterministic = np.maximum(0, current_balance - days_elapsed * avg_consumption).tolist()

    # Probabilistic forecast: each day's consumption is scaled by a random factor.
    daily_variation = 1 + (rng.rand(horizon) - 0.5) * (variability / 100)
    daily_use = avg_consumption * daily_variation
    # Exclusive running total (consumption strictly before each day), computed
    # in one pass instead of re-summing a growing slice for every day.
    consumed_before_day = np.cumsum(daily_use) - daily_use
    probabilistic = np.maximum(0, current_balance - consumed_before_day).tolist()

    return dates, deterministic, probabilistic

def train_supervised_model(X, y):
    """Fit a random forest on the first 70% of the rows and return the hold-out.

    The split is not shuffled, so the last 30% of the rows (in the order they
    are passed in) form the test portion.

    Returns:
        ``(model, X_test, y_test)``: the fitted regressor plus the held-out
        features and targets.
    """
    features_fit, features_holdout, target_fit, target_holdout = train_test_split(
        X, y, test_size=0.3, shuffle=False
    )
    forest = RandomForestRegressor(n_estimators=100, random_state=42)
    forest.fit(features_fit, target_fit)
    return forest, features_holdout, target_holdout

def apply_unsupervised_learning(data, n_clusters=3):
    """Cluster ``data`` with KMeans and return ``(labels, fitted_estimator)``.

    ``n_clusters`` sets the number of clusters; the estimator is seeded with
    ``random_state=42``.
    """
    estimator = KMeans(n_clusters=n_clusters, random_state=42)
    labels = estimator.fit_predict(data)
    return labels, estimator

# --------------------------
# Page Functions
# --------------------------
def dashboard_page():
    """Render the placeholder Dashboard page (a title and a welcome line)."""
    st.title("📊 Dashboard")
    st.write("Welcome to the dashboard.")

def monthly_view():
    """Render the placeholder Monthly View page (a title and one line of text)."""
    st.title("📅 Monthly View")
    st.write("Monthly forecasting analysis here.")

def model_performance():
    """Render the placeholder Model Performance page (a title and one line of text)."""
    st.title("📈 Model Performance")
    st.write("Model metrics and evaluation.")

def ml_insights():
    """Render the placeholder ML Insights page (a title and one line of text)."""
    st.title("🤖 Machine Learning Insights")
    st.write("Results from ML modeling.")

# --------------------------
# UI Components
# --------------------------
# Initialize session states: seed every slot the script relies on, leaving any
# value that is already present untouched.
for _state_key, _make_default in (
    ('current_stocks', pd.DataFrame),
    ('historical_data', pd.DataFrame),
    ('ml_models', dict),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# Sidebar navigation: the selected label decides which page body runs below.
page = st.sidebar.radio("Navigation", ["Dashboard", "Monthly View", "Model Performance", "ML Insights", "Advanced Forecasting"])

# Page routing: the first four pages delegate to the placeholder functions
# defined above; "Advanced Forecasting" is implemented inline in its branch.
if page == "Dashboard":
    dashboard_page()
elif page == "Monthly View":
    monthly_view()
elif page == "Model Performance":
    model_performance()
elif page == "ML Insights":
    ml_insights()
elif page == "Advanced Forecasting":
    # --------------------------
    # Advanced Forecasting UI
    # --------------------------
    # Everything inside this `with` block is rendered in the sidebar.
    with st.sidebar:
        st.header("⚙️ Configuration")
        
        # Current Inventory Upload
        # The upload is validated BEFORE it is stored: the material selection
        # code further down reads both required columns, so a frame missing
        # either one must never reach session state. A rejected upload resets
        # the stored inventory so stale data from an earlier file is not used.
        with st.expander("📥 Upload Current Inventory", expanded=True):
            uploaded_stocks = st.file_uploader(
                "Upload current inventory (CSV)", 
                type=['csv'],
                help="Upload CSV with columns: 'Item Description' and 'Quantity In Sqr Meters'"
            )
            
            if uploaded_stocks is not None:
                required_cols = {'Item Description', 'Quantity In Sqr Meters'}
                try:
                    stocks_df = pd.read_csv(uploaded_stocks)
                except Exception as e:
                    # UI boundary: report any parser failure instead of raising.
                    st.session_state.current_stocks = pd.DataFrame()
                    st.error(f"Error reading file: {str(e)}")
                else:
                    missing = required_cols - set(stocks_df.columns)
                    if missing:
                        st.session_state.current_stocks = pd.DataFrame()
                        st.error(f"Missing columns: {', '.join(sorted(missing))}")
                    else:
                        # Only a fully validated frame is stored and announced.
                        st.session_state.current_stocks = stocks_df
                        st.success(f"Uploaded {len(stocks_df)} records")
                        st.dataframe(stocks_df.head(3))
        
        # Material selection
        saved_materials = get_saved_materials()
        material_list = ["Create New"] + saved_materials
        
        # Add materials from uploaded inventory. Both columns are required
        # because the balance lookup below reads the quantity column as well.
        stocks = st.session_state.current_stocks
        has_inventory = (
            not stocks.empty
            and {'Item Description', 'Quantity In Sqr Meters'}.issubset(stocks.columns)
        )
        if has_inventory:
            # Descriptions are compared as text so numeric item codes still work
            # as material names (file names are built with str.replace); blank
            # cells are not offered as materials.
            item_names = stocks['Item Description'].astype(str)
            uploaded_materials = item_names[stocks['Item Description'].notna()].unique().tolist()
            material_list = uploaded_materials + [m for m in saved_materials if m not in uploaded_materials]
            if not material_list:
                # Nothing selectable in the upload and nothing saved yet.
                material_list = ["Create New"]
        
        material_option = st.selectbox("Select material", material_list)
        
        loaded_data = None
        if material_option == "Create New":
            material_name = st.text_input("New Material Name", "White BOPP 38 Micron Film")
        else:
            material_name = material_option
            loaded_data = load_material_data(material_name)
        
        # A material that comes from the uploaded inventory may never have been
        # saved, in which case load_material_data() returns None. Fall back to
        # the "Create New" defaults instead of indexing into None.
        saved_config = loaded_data if isinstance(loaded_data, dict) else {}
        
        # Auto-populate current balance.
        # Precedence: quantity from the uploaded inventory, then the saved
        # configuration, then the generic default.
        default_balance = float(saved_config.get('current_balance', 1000.0))
        if has_inventory:
            # Non-numeric quantities become NaN and are skipped rather than
            # crashing the float() conversion.
            quantities = pd.to_numeric(
                stocks.loc[item_names == material_name, 'Quantity In Sqr Meters'],
                errors='coerce',
            ).dropna()
            if not quantities.empty:
                # Clamp at zero: the number input below rejects negative defaults.
                default_balance = max(0.0, float(quantities.iloc[0]))
                st.info(f"Current stock: {default_balance:,.2f} sqm")
        
        # Forecasting parameters (explicit float/int casts keep each default the
        # same numeric type as the widget's min_value).
        current_balance = st.number_input("Current Balance (sqm)", min_value=0.0, value=default_balance)
        avg_consumption = st.number_input("Avg Daily Consumption (sqm)", min_value=0.0, value=float(saved_config.get('avg_consumption', 50.0)))
        variability = st.slider("Consumption Variability (%)", 0, 50, int(saved_config.get('variability', 10)))
        safety_stock = st.number_input("Safety Stock (sqm)", min_value=0.0, value=float(saved_config.get('safety_stock', 200.0)))
        lead_time = st.number_input("Lead Time (days)", min_value=1, value=int(saved_config.get('lead_time', 7)))
        forecast_horizon = st.selectbox("Forecast Horizon", ["30 days", "60 days", "90 days"], index=0)
        
        # Persist the current parameter values under the selected material name.
        if st.button("💾 Save Material Configuration"):
            if not str(material_name).strip():
                # A blank name would be written as a bare ".pkl" file and then
                # show up as an empty entry in the material list.
                st.error("Enter a material name before saving.")
            else:
                data = {
                    'current_balance': current_balance,
                    'avg_consumption': avg_consumption,
                    'variability': variability,
                    'safety_stock': safety_stock,
                    'lead_time': lead_time
                }
                try:
                    save_material_data(material_name, data)
                except OSError as e:
                    # e.g. the name contains a path separator or a character
                    # the file system rejects; report it instead of crashing.
                    st.error(f"Could not save {material_name}: {e}")
                else:
                    st.success(f"Saved {material_name} configuration!")
        
        st.divider()
        
        # Historical Data Upload
        # The uploaded history is validated, previewed, and used to train a
        # random forest (calendar features -> quantity) and a KMeans clustering
        # of the quantities. Results are kept in st.session_state.ml_models.
        with st.expander("📤 Upload Historical Data"):
            uploaded_file = st.file_uploader(
                "Upload consumption history (CSV)", 
                type=['csv'],
                help="Upload historical data with Date and Quantity In Sqr Meters columns"
            )
            if uploaded_file:
                try:
                    st.session_state.historical_data = pd.read_csv(uploaded_file)
                    
                    # Validate historical data columns
                    hist_cols = set(st.session_state.historical_data.columns)
                    if not {'Date', 'Quantity In Sqr Meters'}.issubset(hist_cols):
                        st.error("Historical data must contain 'Date' and 'Quantity In Sqr Meters' columns")
                    else:
                        st.success(f"Uploaded {len(st.session_state.historical_data)} historical records")
                        st.dataframe(st.session_state.historical_data.head(3))
                        
                        # Streamlit reruns this script on every widget
                        # interaction while the uploader still holds the file.
                        # Fingerprint the data (column names + row hashes) so
                        # the models are retrained only when the content
                        # actually changes, not on every slider move.
                        row_hashes = pd.util.hash_pandas_object(st.session_state.historical_data)
                        data_fingerprint = (
                            tuple(st.session_state.historical_data.columns),
                            hash(row_hashes.to_numpy().tobytes()),
                        )
                        models_are_current = (
                            bool(st.session_state.ml_models)
                            and st.session_state.get('ml_models_fingerprint') == data_fingerprint
                        )
                        
                        if models_are_current:
                            st.success("Machine learning models trained successfully!")
                        else:
                            # Train ML models when historical data is uploaded
                            with st.spinner("Training machine learning models..."):
                                try:
                                    # Prepare data for supervised learning
                                    historical = st.session_state.historical_data.copy()
                                    historical['Date'] = pd.to_datetime(historical['Date'])
                                    historical = historical.sort_values('Date')
                                    
                                    # Feature engineering
                                    historical['day_of_week'] = historical['Date'].dt.dayofweek
                                    historical['day_of_month'] = historical['Date'].dt.day
                                    historical['month'] = historical['Date'].dt.month
                                    
                                    X = historical[['day_of_week', 'day_of_month', 'month']]
                                    y = historical['Quantity In Sqr Meters']
                                    
                                    # Train supervised model
                                    model, X_test, y_test = train_supervised_model(X, y)
                                    supervised_pred = model.predict(X_test)
                                    
                                    # Apply unsupervised learning
                                    clusters, kmeans = apply_unsupervised_learning(historical[['Quantity In Sqr Meters']].values)
                                    
                                    # Store models in session state
                                    st.session_state.ml_models = {
                                        'supervised': model,
                                        'supervised_metrics': calculate_metrics(y_test, supervised_pred),
                                        'clusters': clusters,
                                        'kmeans': kmeans,
                                        'historical': historical
                                    }
                                    # Remember which data these models belong to.
                                    st.session_state.ml_models_fingerprint = data_fingerprint
                                    
                                    st.success("Machine learning models trained successfully!")
                                except Exception as e:
                                    # UI boundary: show the failure; the next
                                    # rerun will attempt training again.
                                    st.error(f"Error training models: {str(e)}")
                except Exception as e:
                    st.error(f"Error reading file: {str(e)}")

    # --------------------------
    # Forecasting Logic
    # --------------------------
    # The horizon label looks like "30 days"; the leading number is the day count.
    horizon_days = int(forecast_horizon.partition(" ")[0])
    dates, deterministic, probabilistic = generate_forecast(
        current_balance, avg_consumption, variability, horizon_days
    )

    # Create DataFrame (the two threshold columns hold one constant each)
    reorder_point = safety_stock + (lead_time * avg_consumption)
    df = pd.DataFrame({
        'Date': dates,
        'Deterministic': deterministic,
        'Probabilistic': probabilistic,
        'Reorder_Point': reorder_point,
        'Safety_Stock': safety_stock,
    })

    # Train-test split (time-series aware): first 70% of the days, then the rest
    split_idx = int(horizon_days * 0.7)
    train, test = df.iloc[:split_idx], df.iloc[split_idx:]

    # Calculate metrics: the deterministic series plays the role of "actual"
    metrics = calculate_metrics(test['Deterministic'], test['Probabilistic'])

    # --------------------------
    # Dashboard Layout
    # --------------------------
    tab1, tab2, tab3, tab4 = st.tabs(["📈 Forecast Dashboard", "📆 Monthly View", "📊 Model Performance", "🤖 ML Insights"])

    # Tab 1: forecast chart on the left, headline numbers and alerts on the right.
    with tab1:
        col1, col2 = st.columns([3, 1])
        
        with col1:
            fig = go.Figure()
            fig.add_trace(go.Scatter(
                x=df['Date'], y=df['Deterministic'], 
                name='Deterministic Forecast',
                line=dict(color='blue'),
                hovertemplate='Date: %{x}<br>Quantity: %{y:.2f} sqm<extra></extra>'
            ))
            fig.add_trace(go.Scatter(
                x=df['Date'], y=df['Probabilistic'], 
                name='Probabilistic Forecast',
                line=dict(color='green', dash='dot'),
                hovertemplate='Date: %{x}<br>Quantity: %{y:.2f} sqm<extra></extra>'
            ))
            # Both thresholds are constant over the horizon, so the first row
            # is enough to position the horizontal lines.
            fig.add_hline(
                y=df['Reorder_Point'].iloc[0], 
                line_dash='dot', 
                line_color='orange', 
                name='Reorder Point',
                annotation_text=f"Reorder Point: {df['Reorder_Point'].iloc[0]:.2f} sqm"
            )
            fig.add_hline(
                y=df['Safety_Stock'].iloc[0], 
                line_dash='dot', 
                line_color='red', 
                name='Safety Stock',
                annotation_text=f"Safety Stock: {df['Safety_Stock'].iloc[0]:.2f} sqm"
            )
            fig.update_layout(
                title=f"{material_name} Forecast (Sqm)",
                xaxis_title='Date',
                yaxis_title='Quantity (sqm)',
                hovermode='x unified',
                legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
            )
            st.plotly_chart(fig, use_container_width=True)
        
        with col2:
            st.metric("Current Balance", f"{current_balance:,.2f} sqm")
            # The consumption input allows 0, which used to raise
            # ZeroDivisionError here. With no consumption the stock never runs
            # out, so there is no finite day count and no stock alert.
            if avg_consumption > 0:
                days_remaining = int(current_balance / avg_consumption)
                st.metric("Days Until Stockout", days_remaining)
            else:
                days_remaining = None
                st.metric("Days Until Stockout", "∞")
            st.metric("Reorder Point", f"{df['Reorder_Point'].iloc[0]:,.2f} sqm")
            st.metric("Avg Daily Use", f"{avg_consumption:,.2f} sqm/day")
            
            # Stock alerts
            if days_remaining is not None:
                if days_remaining < lead_time:
                    st.error("⚠️ Warning: Stock may run out before replenishment!")
                elif days_remaining < lead_time * 1.5:
                    st.warning("⚠️ Alert: Consider reordering soon")

    # Tab 2: month-end summary of the daily forecast (lowest balance per month).
    with tab2:
        daily_forecast = df.set_index('Date')
        # 'ME' is the month-end alias introduced in pandas 2.2, where the old
        # 'M' spelling was deprecated. Older pandas rejects 'ME' with a
        # ValueError, so fall back to 'M' there.
        try:
            monthly_bins = daily_forecast.resample('ME')
        except ValueError:
            monthly_bins = daily_forecast.resample('M')
        df_monthly = monthly_bins.agg({
            'Deterministic': 'min',
            'Probabilistic': 'min',
            'Reorder_Point': 'first',
            'Safety_Stock': 'first'
        }).reset_index()
        
        fig_month = px.line(
            df_monthly, 
            x='Date', 
            y=['Deterministic', 'Probabilistic', 'Reorder_Point', 'Safety_Stock'],
            title="Monthly Forecast Summary (Sqm)",
            labels={'value': 'Quantity (sqm)', 'variable': 'Metric'},
            hover_data={'value': ':.2f'}
        )
        fig_month.update_traces(hovertemplate='Date: %{x}<br>Quantity: %{y:.2f} sqm<extra></extra>')
        st.plotly_chart(fig_month, use_container_width=True)
        
        # Same figures as a table, with months shown as YYYY-MM.
        styled_df = df_monthly.style.format({
            'Date': lambda x: x.strftime('%Y-%m'),
            'Deterministic': '{:,.2f}',
            'Probabilistic': '{:,.2f}',
            'Reorder_Point': '{:,.2f}',
            'Safety_Stock': '{:,.2f}'
        })
        st.dataframe(styled_df, use_container_width=True)

    # Tab 3: metric cards, an actual-vs-predicted chart and a residual plot for
    # the last 30% of the forecast. "Actual" is the deterministic series and
    # "Predicted" is the probabilistic one.
    with tab3:
        st.subheader("Model Performance (70% train / 30% test)")
        
        metric_cards = (
            ("RMSE", f"{metrics['RMSE']:.2f} sqm",
             "Root Mean Square Error - Lower values indicate better fit"),
            ("MAPE", f"{metrics['MAPE']:.2f}%",
             "Mean Absolute Percentage Error - Percentage error in predictions"),
            ("R² Score", f"{metrics['R2']:.2f}",
             "Coefficient of Determination - 1.0 is perfect prediction"),
        )
        for card_column, (card_label, card_value, card_help) in zip(st.columns(3), metric_cards):
            card_column.metric(card_label, card_value, help=card_help)
        
        # One trace per series; the hover text carries the trace's own label.
        fig_test = go.Figure()
        for series_column, trace_label in (('Deterministic', 'Actual'), ('Probabilistic', 'Predicted')):
            fig_test.add_trace(go.Scatter(
                x=test['Date'],
                y=test[series_column],
                name=trace_label,
                mode='lines+markers',
                hovertemplate='Date: %{x}<br>' + trace_label + ': %{y:.2f} sqm<extra></extra>',
            ))
        legend_above_plot = dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
        fig_test.update_layout(
            title="Test Set: Actual vs Predicted (Sqm)",
            xaxis_title='Date',
            yaxis_title='Quantity (sqm)',
            legend=legend_above_plot,
        )
        st.plotly_chart(fig_test, use_container_width=True)
        
        # Residuals around a zero reference line.
        residuals = test['Deterministic'].sub(test['Probabilistic'])
        residual_trace = go.Scatter(
            x=test['Date'],
            y=residuals,
            mode='markers',
            name='Residuals',
            hovertemplate='Date: %{x}<br>Residual: %{y:.2f} sqm<extra></extra>',
        )
        fig_residuals = go.Figure()
        fig_residuals.add_trace(residual_trace)
        fig_residuals.add_hline(y=0, line_dash='dash', line_color='red')
        fig_residuals.update_layout(
            title="Residual Plot (Actual - Predicted)",
            xaxis_title='Date',
            yaxis_title='Residual (sqm)',
        )
        st.plotly_chart(fig_residuals, use_container_width=True)

    # Tab 4: results of the models trained from the uploaded history. Shown
    # only once st.session_state.ml_models has been filled by a training run.
    with tab4:
        if st.session_state.ml_models:
            st.subheader("Supervised Learning Insights")
            
            col1, col2 = st.columns(2)
            with col1:
                st.metric("Supervised Model RMSE", 
                         f"{st.session_state.ml_models['supervised_metrics']['RMSE']:.2f} sqm",
                         help="Root Mean Square Error of Random Forest model")
                st.metric("Supervised Model R²", 
                         f"{st.session_state.ml_models['supervised_metrics']['R2']:.2f}",
                         help="Coefficient of Determination of Random Forest model")
            
            with col2:
                # Labels follow the feature column order used for training:
                # day_of_week, day_of_month, month.
                feature_importance = pd.DataFrame({
                    'Feature': ['Day of Week', 'Day of Month', 'Month'],
                    'Importance': st.session_state.ml_models['supervised'].feature_importances_
                })
                fig_importance = px.bar(
                    feature_importance, 
                    x='Feature', 
                    y='Importance',
                    title="Feature Importance",
                    labels={'Importance': 'Relative Importance'}
                )
                st.plotly_chart(fig_importance, use_container_width=True)
            
            st.subheader("Unsupervised Learning Clusters")
            
            historical = st.session_state.ml_models['historical'].copy()
            historical['Cluster'] = st.session_state.ml_models['clusters']
            cluster_stats = historical.groupby('Cluster')['Quantity In Sqr Meters'].describe()
            
            col1, col2 = st.columns(2)
            with col1:
                fig_clusters = px.scatter(
                    historical,
                    x='Date',
                    y='Quantity In Sqr Meters',
                    color='Cluster',
                    title="Consumption Patterns Clusters",
                    hover_data=['Quantity In Sqr Meters'],
                    labels={'Quantity In Sqr Meters': 'Quantity (sqm)'}
                )
                st.plotly_chart(fig_clusters, use_container_width=True)
            
            with col2:
                st.write("Cluster Statistics:")
                st.dataframe(cluster_stats.style.format('{:,.2f}'))
            
            centroids = pd.DataFrame(
                st.session_state.ml_models['kmeans'].cluster_centers_,
                columns=['Quantity (sqm)']
            )
            centroids['Cluster'] = centroids.index
            # Align member counts to the centroid rows by cluster id. When the
            # data has fewer distinct quantities than clusters, some cluster
            # ids receive no members; a bare value_counts() would then be
            # shorter than the centroid table and the assignment would fail.
            centroids['Count'] = (
                historical['Cluster']
                .value_counts()
                .reindex(centroids.index, fill_value=0)
                .values
            )
            
            fig_centroids = px.bar(
                centroids, 
                x='Cluster', 
                y='Quantity (sqm)',
                text='Count',
                title="Cluster Centroids with Member Count"
            )
            st.plotly_chart(fig_centroids, use_container_width=True)
        else:
            st.warning("Upload and process historical data to enable machine learning insights")

    # --------------------------
    # Data Export
    # --------------------------
    # Sidebar export panel: the forecast is always downloadable, the uploaded
    # history only when there is some.
    st.sidebar.divider()
    export_panel = st.sidebar.expander("📤 Export Data")
    with export_panel:
        forecast_csv = df.to_csv(index=False).encode('utf-8')
        forecast_file_name = material_name.replace(' ', '_') + "_forecast.csv"
        st.download_button(
            label="📥 Download Forecast Data",
            data=forecast_csv,
            file_name=forecast_file_name,
            mime='text/csv',
            help="Download the complete forecast data as CSV",
        )
        
        history_frame = st.session_state.historical_data
        if not history_frame.empty:
            history_csv = history_frame.to_csv(index=False).encode('utf-8')
            st.download_button(
                label="📥 Download Historical Data",
                data=history_csv,
                file_name="historical_data.csv",
                mime='text/csv',
            )

# --------------------------
# About Section
# --------------------------
# Static "About" panel at the bottom of the sidebar, shown on every page.
about_markdown = """
    **SKANEM Supply Chain Forecasting Tool**  
    Version 2.0  
    Developed for SKANEM AS  
    
    Features:  
    - Inventory forecasting  
    - Machine learning insights  
    - Time-series analysis  
    
    © 2025 SKANEM AS
    """
st.sidebar.divider()
about_panel = st.sidebar.expander("About")
with about_panel:
    st.write(about_markdown)

