# In [None]:
#%%
from datetime import date, timedelta
from itertools import product
import warnings

import pandas as pd
import numpy as np
import pyodbc
from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import RandomForestRegressor

from sqlalchemy import create_engine, text
from crepes import WrapRegressor
from crepes.extras import DifficultyEstimator

import os
import sys

# Directory containing this script; used to locate the sibling Logging_Script
# folder and the products.txt / branches.txt input files.
# NOTE(review): relies on __file__, which is normally undefined when cells are
# run interactively — confirm how this file is executed.
script_dir = os.path.dirname(os.path.abspath(__file__))

# Absolute path to the Logging_Script folder, one level above this script's folder.
logging_script_dir = os.path.abspath(
    os.path.join(script_dir, "..", "Logging_Script")
)

# Make logging_config importable; this must run before the import below.
if logging_script_dir not in sys.path:
    sys.path.append(logging_script_dir)

from logging_config import setup_logger

# Shared logger used by every function and cell in this file.
logger = setup_logger("PREDICTIONS")

# Silence one specific warning, matched by its message text.
# NOTE(review): presumably emitted by scikit-learn during fitting — confirm the source.
warnings.filterwarnings(
    "ignore",
    message="The number of unique classes is greater than 50% of the number of samples",
)

# Run date: used to validate the latest sales date and to tag/replace rows in
# the results table.
today_date = date.today()
# Manual override for a back-dated run (left disabled):
# today_date = today_date - timedelta(days=1)
#%%

def _add_calendar_columns(frame: pd.DataFrame) -> pd.DataFrame:
    """Add month/day/weekday columns and weekend/month-boundary flags from InvoiceDate (in place)."""
    frame['month'] = frame['InvoiceDate'].dt.month
    frame['day_of_month'] = frame['InvoiceDate'].dt.day
    frame['day_of_week'] = frame['InvoiceDate'].dt.dayofweek

    frame['is_weekend'] = (frame['day_of_week'] >= 5).astype(int)
    frame['is_month_start'] = frame['InvoiceDate'].dt.is_month_start.astype(int)
    frame['is_month_end'] = frame['InvoiceDate'].dt.is_month_end.astype(int)
    return frame


def _one_hot_calendar(frame: pd.DataFrame) -> pd.DataFrame:
    """One-hot encode month/day_of_month/day_of_week into Month_*/Day_*/Weekday_* integer columns."""
    encoded = pd.get_dummies(frame, columns=['month', 'day_of_month', 'day_of_week'],
                             prefix=['Month', 'Day', 'Weekday'], drop_first=False)
    return encoded.astype({col: 'int' for col in encoded.select_dtypes('bool').columns})


def preProcessingSteps(df_to_preprocess: pd.DataFrame):
    """Build a dense daily Product x Branch sales panel plus one forecast row per pair.

    Steps:
      1. Restrict the date range to at most the last 2*365 days before the
         latest InvoiceDate.
      2. Create every (date, Product, Branch) combination and left-join the
         summed Quantity onto it; combinations without sales get 0.
      3. Attach TopCategory/LastCategory, calendar flags and one-hot encoded
         month / day-of-month / weekday columns, and a label-encoded branch.
      4. Append one row per (Product, Branch) dated the day after the latest
         InvoiceDate, with Quantity set to NaN, to be filled by the forecast.

    Args:
        df_to_preprocess: Sales rows with at least the columns InvoiceDate
            (datetime), Product, Branch, Quantity, TopCategory, LastCategory.

    Returns:
        Tuple ``(final_df, tomorrow_date)``: the panel including the appended
        forecast rows, and the forecast date.

    Raises:
        Exception: any failure is logged at CRITICAL level and re-raised.
    """

    logger.info("Preprocessing of Sales df started")
    try:
        row_count = len(df_to_preprocess)
        logger.info("Input dataframe rows: %d", row_count)

        min_date = df_to_preprocess['InvoiceDate'].min()
        max_date = df_to_preprocess['InvoiceDate'].max()
        logger.info("InvoiceDate range: %s -> %s", min_date, max_date)

        # Keep at most two years of history.
        cutoff_date = max_date - timedelta(days=2*365)
        if min_date < cutoff_date:
            logger.info("Min date %s older than cutoff %s, adjusting", min_date, cutoff_date)
            min_date = cutoff_date
        logger.info("Final date range used: %s -> %s", min_date, max_date)

        all_dates = [min_date + timedelta(days=i) for i in range((max_date - min_date).days + 1)]
        Products = sorted(df_to_preprocess['Product'].dropna().unique())
        Branch = sorted(df_to_preprocess['Branch'].dropna().unique())
        logger.info("Unique Products: %d, Branches: %d", len(Products), len(Branch))

        grid_data = list(product(all_dates, Products, Branch))
        full_grid = pd.DataFrame(grid_data, columns=['InvoiceDate', 'Product', 'Branch'])
        logger.info("Full grid created with %d rows", len(full_grid))

        sales_agg = df_to_preprocess.groupby(['InvoiceDate', 'Product', 'Branch']).agg({'Quantity': 'sum'}).reset_index()
        logger.info("Aggregated sales rows: %d", len(sales_agg))

        # Left join keeps only grid dates; days without sales become Quantity 0.
        merged = full_grid.merge(sales_agg, on=['InvoiceDate', 'Product', 'Branch'], how='left').fillna(0)
        logger.info("Merged dataset rows: %d", len(merged))

        # Exactly one metadata row per product. A product recorded with more
        # than one TopCategory/LastCategory pair would otherwise duplicate every
        # grid row in the merge below, which corrupts the per-series lag
        # features and yields several forecast rows for the same pair.
        product_meta = (
            df_to_preprocess[['Product', 'TopCategory', 'LastCategory']]
            .dropna(subset=['Product'])
            .drop_duplicates()
        )
        conflicting = int(product_meta['Product'].duplicated().sum())
        if conflicting:
            logger.warning(
                "%d extra category rows found for products with inconsistent "
                "TopCategory/LastCategory; keeping the first per product",
                conflicting,
            )
            product_meta = product_meta.drop_duplicates(subset='Product', keep='first')
        Product_merge = merged.merge(product_meta, on='Product', how='left')

        final_df = Product_merge.sort_values(['Product', 'Branch', 'InvoiceDate'])
        # Negative daily totals are floored at zero.
        final_df['Quantity'] = final_df['Quantity'].clip(lower=0)

        final_df = _one_hot_calendar(_add_calendar_columns(final_df))
        logger.info("Feature engineering completed")

        le_branch = LabelEncoder()
        final_df['BranchEncoded'] = le_branch.fit_transform(final_df['Branch'])

        tomorrow_date = final_df['InvoiceDate'].max() + timedelta(days=1)
        logger.info("Tomorrow date generated: %s", tomorrow_date)

        tomorrow_combinations = final_df[['Product', 'Branch', 'TopCategory', 'LastCategory', 'BranchEncoded']].drop_duplicates()
        tomorrow_combinations['InvoiceDate'] = tomorrow_date
        tomorrow_encoded = _one_hot_calendar(_add_calendar_columns(tomorrow_combinations))

        # A single date only produces one dummy per group; add the remaining
        # dummy columns seen in the history as zeros so the schemas match.
        dummy_columns = [col for col in final_df.columns if col.startswith(('Month_', 'Day_', 'Weekday_'))]
        missing_cols = 0
        for col in dummy_columns:
            if col not in tomorrow_encoded.columns:
                tomorrow_encoded[col] = 0
                missing_cols += 1
        logger.info("Missing dummy columns added for tomorrow: %d", missing_cols)

        # The target is unknown for the forecast rows.
        numeric_features = ['Quantity']
        for col in numeric_features:
            tomorrow_encoded[col] = np.nan

        tomorrow_encoded = tomorrow_encoded[final_df.columns]
        final_df = pd.concat([final_df, tomorrow_encoded], ignore_index=True)
        logger.info("Preprocessing completed successfully, final rows: %d", len(final_df))

        return final_df, tomorrow_date

    except Exception:
        logger.critical("Preprocessing failed", exc_info=True)
        raise



def preForecastingSteps(df: pd.DataFrame, tomorrow_date):
    """Add lag/rolling features and assemble the model feature list.

    For every (Product, Branch) series, ordered by InvoiceDate, this adds
    ``lag_1`` .. ``lag_5`` and three rolling statistics computed on the
    series shifted by one day (so a row never sees its own Quantity).
    Historical rows with a missing feature or target are dropped; rows dated
    ``tomorrow_date`` are always kept.

    Returns:
        Tuple ``(df, features, target)``: the prepared frame, the ordered
        list of feature column names, and the target column name.
    """
    logger.info("PreForecastingSteps started")
    series_keys = ['Product', 'Branch']
    df = df.sort_values(series_keys + ['InvoiceDate'])

    def _per_series(fn):
        # Apply fn to each series' Quantity, aligned back to the rows of df.
        return df.groupby(series_keys)['Quantity'].transform(fn)

    # Feature engineering: lag and rolling stats
    for lag in range(1, 6):
        df[f'lag_{lag}'] = _per_series(lambda s: s.shift(lag))

    rolling_specs = (
        ('rolling_mean_3', 3, 'mean'),
        ('rolling_mean_7', 7, 'mean'),
        ('rolling_std_3', 3, 'std'),
    )
    for column, window, stat in rolling_specs:
        df[column] = _per_series(
            lambda s: getattr(s.shift(1).rolling(window), stat)()
        )

    logger.info("Feature engineering completed: lag and rolling features added")

    # Base + lag + rolling features
    calendar_features = (
        [f'Month_{m}' for m in range(1, 13)]
        + [f'Day_{d}' for d in range(1, 32)]
        + [f'Weekday_{w}' for w in range(7)]
        + ['is_weekend', 'is_month_start', 'is_month_end']
    )
    festival_features = [
        'New_Year', 'Valentine_Day', 'Days_before_Ramazan', 'Ramazan',
        'Eid_al_Fitr', 'Days_after_Eid_al_Fitr', 'Days_before_Eid_al_Adha',
        'Eid_al_Adha', 'Muharram', 'Day_before_Independence_day',
        'Independence_day', 'Eid_Milad_un_Nabi',
    ]
    base_features = calendar_features + festival_features

    lag_features = [name for name in df.columns if name.startswith('lag_')]
    rolling_features = [name for name in df.columns if name.startswith('rolling_')]
    features = base_features + lag_features + rolling_features
    target = 'Quantity'

    # History must be complete; the forecast rows keep their NaN target.
    is_forecast_row = df["InvoiceDate"] == tomorrow_date
    history = df[~is_forecast_row].dropna(subset=features + [target])
    df = pd.concat([history, df[is_forecast_row]], ignore_index=True)

    logger.info("PreForecastingSteps completed, dataframe ready with %d rows", len(df))
    return df, features, target


def calculate_weighted_mda(actual_changes, predicted_changes, prev_actuals):
    """Return a weighted directional-accuracy score for predicted changes.

    Each observation scores ``1 / (1 + |predicted - actual|)`` when the
    predicted and actual changes have the same sign and half of that
    otherwise. Scores are weighted by the previous actual value divided by
    the sum of all previous actuals (plus 1e-6 to avoid division by zero).
    Returns NaN when there are no observations.
    """
    weight_denominator = sum(prev_actuals) + 1e-6

    def _contribution(actual, predicted, prev):
        gap = abs(predicted - actual)
        if np.sign(actual) == np.sign(predicted):
            base = 1 / (1 + gap)
        else:
            base = 0.5 / (1 + gap)
        return base * (prev / weight_denominator)

    contributions = [
        _contribution(actual, predicted, prev)
        for actual, predicted, prev in zip(actual_changes, predicted_changes, prev_actuals)
    ]
    if not contributions:
        return np.nan
    return np.sum(contributions)

def calculate_rmse(y_true, y_pred):
    """Return the root-mean-squared error of two equal-length arrays, or NaN if they are empty."""
    squared_errors = (y_true - y_pred) ** 2
    if len(squared_errors) > 0:
        return np.sqrt(np.mean(squared_errors))
    return np.nan


def forecasting(df: pd.DataFrame, features: list, target: str, tomorrow_date):
    """Forecast ``target`` for ``tomorrow_date`` for every (Product, Branch) series.

    For each series the rows are split by date into:
      * train       — everything before the validation window,
      * validation  — the 20 days before the calibration window,
      * calibration — the 20 days before the forecast date,
      * test        — the row(s) dated ``tomorrow_date``.

    Each entry of a small random-forest parameter grid is fitted on train and
    scored on validation with a weighted directional score (WMDA) and RMSE.
    The two are normalised and combined into a hybrid score that picks the
    parameters. The chosen forest is refitted on train+validation, wrapped in
    a ``WrapRegressor``, calibrated on the calibration window with a
    ``DifficultyEstimator``, and used to produce a point forecast plus
    50/60/70/80/90 % prediction intervals (lower bound floored at 0).

    Series with no test row, too little data in any split, or no successful
    fit are skipped with a warning.

    Args:
        df: Frame returned by ``preForecastingSteps``.
        features: Feature column names.
        target: Target column name.
        tomorrow_date: Date to forecast.

    Returns:
        One row per forecast with identifiers, ``Predicted``,
        ``Lower_*``/``Upper_*`` bounds and the selection scores.

    Raises:
        RuntimeError: if no series produced a forecast.
    """
    logger.info("Forecasting started for date %s", tomorrow_date)

    param_grid = [
        {"n_estimators": 150, "max_depth": 8, "max_features": None,
         "min_samples_split": 10, "min_samples_leaf": 3},
        {"n_estimators": 200, "max_depth": 10, "max_features": 0.7,
         "min_samples_split": 10, "min_samples_leaf": 3},
        {"n_estimators": 300, "max_depth": None, "max_features": 0.5,
         "min_samples_split": 12, "min_samples_leaf": 4},
    ]

    val_window_days = 20
    calib_window_days = 20
    forecast_date = tomorrow_date
    # (column suffix, confidence) pairs, in output column order.
    interval_levels = ((90, 0.9), (80, 0.8), (70, 0.7), (60, 0.6), (50, 0.5))

    # The split boundaries depend only on the forecast date, so they are
    # computed once instead of once per series.
    calib_start = forecast_date - timedelta(days=calib_window_days)
    calib_end = forecast_date - timedelta(days=1)
    val_start = calib_start - timedelta(days=val_window_days)
    val_end = calib_start - timedelta(days=1)

    # Collected per-series results; concatenated once after the loop to avoid
    # re-copying the accumulated frame on every iteration.
    result_frames = []

    for i, ((prod, branch), group) in enumerate(df.groupby(["Product", "Branch"]), start=1):
        group = group.sort_values("InvoiceDate").reset_index(drop=True)
        logger.info("Processing product %d | %s-%s with %d rows", i, prod, branch, len(group))

        # Masks
        train_mask = (group['InvoiceDate'] < val_start)
        val_mask = (group['InvoiceDate'] >= val_start) & (group['InvoiceDate'] <= val_end)
        calib_mask = (group['InvoiceDate'] >= calib_start) & (group['InvoiceDate'] <= calib_end)
        test_mask = (group['InvoiceDate'] == forecast_date)

        X_train, y_train = group.loc[train_mask, features], group.loc[train_mask, target]
        X_val, y_val = group.loc[val_mask, features], group.loc[val_mask, target]
        X_calib, y_calib = group.loc[calib_mask, features], group.loc[calib_mask, target]
        X_test = group.loc[test_mask, features]

        if len(X_test) == 0:
            logger.warning("No test rows for %s-%s on %s — skipping", prod, branch, forecast_date)
            continue

        if len(X_train) < 20 or len(X_val) < 10 or len(X_calib) < 10:
            logger.warning("%s-%s Not enough data in splits (train ≥20, val ≥10, calib ≥10) — skipping", prod, branch)
            continue

        # Hyperparameter tuning
        results_eval = []
        for params in param_grid:
            try:
                model_tmp = RandomForestRegressor(n_jobs=-1, random_state=42, **params)
                model_tmp.fit(X_train.values, y_train.values)
                y_val_pred = model_tmp.predict(X_val.values)

                # Compute WMDA & RMSE. Both actual and predicted changes are
                # measured relative to the previous *actual* value.
                prev_actual = y_train.iloc[-1]
                extended_actuals = np.concatenate([[prev_actual], y_val])
                predicted_extended = np.concatenate([[prev_actual], y_val_pred])
                actual_changes = (extended_actuals[1:] - extended_actuals[:-1]) / (extended_actuals[:-1] + 1e-6)
                predicted_changes = (predicted_extended[1:] - extended_actuals[:-1]) / (extended_actuals[:-1] + 1e-6)
                wmda_score = calculate_weighted_mda(actual_changes, predicted_changes, extended_actuals[:-1])
                rmse_score = calculate_rmse(y_val.values, y_val_pred)

                results_eval.append((wmda_score, rmse_score, params))
            except Exception as e:
                # Best effort: one failing configuration must not stop the series.
                logger.exception("Fit failed for %s-%s with params %s: %s", prod, branch, params, e)
                continue

        if not results_eval:
            logger.warning("All parameter fits failed for %s-%s — skipping", prod, branch)
            continue

        # Normalize + hybrid scoring
        wmda_vals = np.array([r[0] for r in results_eval])
        rmse_vals = np.array([r[1] for r in results_eval])
        if np.all(np.isnan(wmda_vals)) or np.all(np.isnan(rmse_vals)):
            logger.warning("Skipping %s-%s — invalid WMDA/RMSE (all NaN)", prod, branch)
            continue

        wmda_low, wmda_high = np.nanpercentile(wmda_vals, 5), np.nanpercentile(wmda_vals, 95)
        rmse_low, rmse_high = np.nanpercentile(rmse_vals, 5), np.nanpercentile(rmse_vals, 95)
        norm = lambda val, low, high: (val - low) / (high - low + 1e-6)
        alpha, beta = 0.8, 1.0
        # Higher WMDA is better, lower RMSE is better.
        hybrid_scores = [
            alpha * norm(w, wmda_low, wmda_high) + beta * (1 - norm(r, rmse_low, rmse_high))
            for w, r in zip(wmda_vals, rmse_vals)
        ]

        # Fall back to the best WMDA when the hybrid score is uninformative
        # (all non-positive) or unusable (all NaN, where nanargmax would raise).
        hybrid_arr = np.array(hybrid_scores)
        if np.all(np.isnan(hybrid_arr)) or np.all(hybrid_arr <= 0):
            best_idx = np.nanargmax(wmda_vals)
        else:
            best_idx = np.nanargmax(hybrid_arr)

        best_wmda, best_rmse, best_params = results_eval[best_idx]
        best_hybrid = hybrid_scores[best_idx]

        # Retrain final model on train+val
        X_full = pd.concat([X_train, X_val], axis=0)
        y_full = pd.concat([y_train, y_val], axis=0)

        rf = WrapRegressor(RandomForestRegressor(n_jobs=-1, random_state=42, **best_params, oob_score=True))
        rf.fit(X_full.values, y_full.values)
        y_pred_train = rf.predict(X_full.values)
        # Plain array, so nothing downstream depends on the pandas index.
        residuals = np.abs(y_full.values - y_pred_train)

        de = DifficultyEstimator()
        de.fit(X_full.values, residuals=residuals, scaler=True, beta=0.01)
        rf.calibrate(X_calib.values, y_calib.values, de=de)

        # Predict on X_test
        preds = rf.predict(X_test.values)

        # Attach identifiers and append
        temp = group.loc[test_mask, ['InvoiceDate','Product','Branch','TopCategory','Quantity']].copy()
        temp['Predicted'] = preds
        for level, confidence in interval_levels:
            bounds = np.asarray(rf.predict_int(X_test.values, confidence=confidence, y_min=0))
            temp[f'Lower_{level}'] = bounds[:, 0]
            temp[f'Upper_{level}'] = bounds[:, 1]
        temp['best_hybrid_score'] = best_hybrid
        temp['best_wmda'] = best_wmda
        temp['best_rmse'] = best_rmse

        result_frames.append(temp)

    final_result_df = pd.concat(result_frames, ignore_index=True) if result_frames else pd.DataFrame()
    logger.info("Forecasting completed | Final result rows: %d", len(final_result_df))

    if final_result_df.empty:
        # No exception is active here, so no exc_info is attached.
        logger.critical("Forecasting failed: final_result_df is empty")
        raise RuntimeError(
            "Forecasting produced no results. Check data availability, "
            "filters, or model training steps."
        )

    logger.info("Total Unique Products Forecasted: %d", final_result_df.Product.nunique())
    return final_result_df



def preprocess_and_validate_dates(df: pd.DataFrame, date_col: str) -> pd.DataFrame:
    """Normalise a date column to midnight timestamps and validate sales freshness.

    The column is converted to text, a ``T`` separator is replaced by a space
    and a trailing fractional-second part is removed before parsing. The
    result is a timezone-naive datetime column truncated to the day. Null
    dates are kept as NaT instead of being parsed as text.

    For the sales frame (``date_col == 'InvoiceDate'``) the latest date must
    equal the module-level ``today_date``. An empty frame is returned without
    validation so the caller can report the missing data itself.

    Args:
        df: Input frame; it is not modified.
        date_col: Name of the date column. ``'InvoiceDate'`` is treated as
            sales data and ``'date'`` as stock data (used for log messages
            and to decide whether the freshness check applies).

    Returns:
        A copy of ``df`` with ``date_col`` converted.

    Raises:
        ValueError: if the latest sales date is not ``today_date``.
        Exception: parsing failures are logged at CRITICAL level and re-raised.
    """

    if date_col == 'InvoiceDate':
        df_name = 'sales_df'
    elif date_col == 'date':
        df_name = 'stocks_df'
    else:
        df_name = 'unknown_df'
        logger.warning("Unknown date column '%s' provided", date_col)

    logger.info("Starting date preprocessing for '%s'", df_name)

    df = df.copy()

    try:
        raw = df[date_col]
        null_mask = raw.isna()
        if null_mask.any():
            logger.warning(
                "%d null values in '%s' of %s are kept as NaT",
                int(null_mask.sum()), date_col, df_name
            )

        # --- Clean up the date strings
        cleaned = (
            raw.astype(str)
            .str.replace('T', ' ', regex=False)
            .str.replace(r'(\.\d+)$', '', regex=True)
        )
        # Nulls were stringified above ("NaT", "None", ...); blank them again
        # so they are parsed as missing instead of as malformed text.
        cleaned = cleaned.mask(null_mask)

        # --- Convert to datetime, truncated to the day
        parsed = pd.to_datetime(cleaned)
        if parsed.dt.tz is not None:
            # Keep the local calendar date and drop the offset.
            parsed = parsed.dt.tz_localize(None)
        df[date_col] = parsed.dt.normalize()
        logger.info("Date conversion to datetime completed for '%s' column", date_col)
    except Exception:
        logger.critical("Preprocessing failed for %s using column '%s'", df_name, date_col, exc_info=True)
        raise

    if df.empty:
        logger.warning("%s is empty — skipping date validation", df_name)
        return df

    # --- Validate that max date equals today's date
    max_date = df[date_col].max().date()
    logger.info(
        "Max date found in %s: %s | Expected today: %s",
        df_name, max_date, today_date
    )

    if max_date != today_date and df_name == 'sales_df':
        # No exception is active yet, so no exc_info is attached.
        logger.critical(
            "Date validation failed for %s | Max %s = %s, Expected = %s",
            df_name, date_col, max_date, today_date
        )
        raise ValueError(
            f"The latest {date_col} ({max_date}) in {df_name} "
            f"does not match today's date ({today_date})."
        )

    logger.info("Date validation Done for %s", df_name)
    return df
#%%


# ============================================================
# SQL Server connection configuration
# ============================================================
from urllib.parse import quote_plus

logger.info("Initializing database connection")

# Raw ODBC connection string; also used later for the direct pyodbc connection.
# NOTE(review): SERVER/UID/PWD are blank in this file — presumably filled in at
# deployment; confirm where the credentials come from.
conn_str = (
    "DRIVER={ODBC Driver 17 for SQL Server};"
    "SERVER=;"
    "DATABASE=united_king;"
    "UID=;"
    "PWD=;"
    "Trusted_Connection=no;"
    "Connection Timeout=600;"
)

# The whole ODBC string must be URL-encoded when passed through odbc_connect;
# escaping only spaces leaves ';', '=', '{', '}', '+' and '%' (e.g. inside a
# password) open to misinterpretation by the URL parser.
engine_str = "mssql+pyodbc:///?odbc_connect=" + quote_plus(conn_str)

try:
    engine = create_engine(engine_str)
    # Open and immediately close one connection to fail fast on bad settings.
    with engine.connect() as conn:
        logger.info("Database connection successful")
except Exception as e:
    logger.critical("Database connection failed", exc_info=True)
    raise ConnectionError(f" Database connection failed: {e}") from e

#%%

# ============================================================
# Read product and branch filters
# ============================================================
logger.info("Reading product and branch filters")

# Absolute paths to data files
products_path = os.path.join(script_dir, "products.txt")
branches_path = os.path.join(script_dir, "branches.txt")

try:
    # One entry per non-blank line; duplicates are removed.
    with open(products_path, "r", encoding="utf-8") as f:
        products = sorted({line.strip() for line in f if line.strip()})
        logger.info("Loaded %d products from products.txt", len(products))

    with open(branches_path, "r", encoding="utf-8") as f:
        branches = sorted({line.strip() for line in f if line.strip()})
        logger.info("Loaded %d branches from branches.txt", len(branches))
except Exception as e:
    logger.critical("Could not load input files", exc_info=True)
    raise RuntimeError(" File loading failed") from e

# An empty list would render as "IN ()", which is not valid SQL.
if not products or not branches:
    logger.critical(
        "Filter files are empty | products=%d, branches=%d",
        len(products), len(branches)
    )
    raise RuntimeError(" Product or branch filter list is empty")

# ============================================================
# Load sales data
# ============================================================
product_params = [f":p{i}" for i in range(len(products))]
branch_params = [f":b{i}" for i in range(len(branches))]

# One bound parameter per product/branch; values are never interpolated.
query_sales = text(f"""
    SELECT *
    FROM sales
    WHERE Product IN ({", ".join(product_params)})
      AND Branch IN ({", ".join(branch_params)});
""")

query_stocks = text("SELECT * FROM stocks;")
query_master = text("SELECT * FROM products_master;")

# Convert literal "\xa0" to real non-breaking space.
# unicode_escape decodes bytes as Latin-1, so going through UTF-8 would turn
# every non-ASCII character into mojibake. Encoding with Latin-1 plus
# backslashreplace round-trips all characters and still expands the escapes.
products_clean = [
    p.encode("latin-1", "backslashreplace").decode("unicode_escape")
    for p in products
]

params = {f"p{i}": prod for i, prod in enumerate(products_clean)}
params.update({f"b{i}": br for i, br in enumerate(branches)})

#%%

logger.info("Starting data extraction from database")

try:
    # ============================================================
    # Load all necessary tables in a single connection
    # ============================================================
    with engine.connect() as conn:
        sales_df = pd.read_sql(query_sales, conn, params=params)
        stocks_df = pd.read_sql(query_stocks, conn)
        products_master_df = pd.read_sql(query_master, conn)
        # Wrapped in text() like the other queries.
        festivals_df = pd.read_sql(text("SELECT * FROM festival_calendar;"), conn)

    logger.info(
        "Data loaded | sales: %d, stocks: %d, master: %d, festivals: %d",
        len(sales_df), len(stocks_df), len(products_master_df), len(festivals_df)
    )

    # Checked before date validation: with no rows the validation would fail
    # first with a misleading "latest date does not match today" error.
    if sales_df.empty:
        logger.critical("Sales dataframe is empty — stopping pipeline")
        raise RuntimeError(" No sales data available for processing.")

    # ============================================================
    # Preprocess dates
    # ============================================================
    sales_df = preprocess_and_validate_dates(sales_df, 'InvoiceDate')
    stocks_df = preprocess_and_validate_dates(stocks_df, 'date')
    festivals_df['date'] = pd.to_datetime(festivals_df['date'])

    # Requested products that have no sales rows at all.
    missed_prod = set(products_clean) - set(sales_df.Product.unique())
    if missed_prod:
        logger.warning("Products missing in sales data: %s", missed_prod)

    # ============================================================
    # Preprocessing steps
    # ============================================================
    processed_df, tomorrow_date = preProcessingSteps(sales_df)

    if processed_df.empty:
        logger.critical("Processed dataframe is empty — stopping pipeline")
        raise RuntimeError(" No processed data available to merge or forecast.")

    # ============================================================
    # Merge with festivals
    # ============================================================
    # NOTE(review): a left join leaves the festival columns as NaN for any
    # date missing from festival_calendar, and preForecastingSteps drops
    # history rows with NaN features — confirm the calendar covers every date.
    logger.info("Merging processed data with festival calendar")
    processed_df = processed_df.merge(
        festivals_df,
        how='left',
        left_on='InvoiceDate',
        right_on='date'
    ).drop(columns=['date'])
    logger.info("Festival merge completed")

    # ============================================================
    # Pre-forecasting steps
    # ============================================================
    df, features, target = preForecastingSteps(processed_df, tomorrow_date)

except Exception:
    # Top-level boundary: log with traceback, then let the failure propagate.
    logger.critical("Pipeline failed unexpectedly", exc_info=True)
    raise

#%%

# Run the per-series forecast, then align the stock date column name with sales.
prediction_results = forecasting(df, features, target, tomorrow_date)
stocks_df = stocks_df.rename(columns={'date': 'InvoiceDate'})

#%%


def _round_by_unit(values):
    """Round to 2 decimals where IsDecimalAllow == 1, otherwise to whole units."""
    return np.where(order_qty['IsDecimalAllow'] == 1, values.round(2), values.round())


# Stock rows joined with product master data, limited to the requested
# branches and products.
_master_cols = products_master_df[['Id', 'Name', 'IsDecimalAllow']].rename(
    columns={'Id': 'id', 'Name': 'Product'}
)
socks_and_masters = stocks_df.merge(_master_cols, on='id', how='outer')

_in_scope = (
    socks_and_masters['CompanyBranch'].isin(branches)
    & socks_and_masters['Product'].isin(products_clean)
)
socks_and_masters = socks_and_masters[_in_scope].rename(columns={'CompanyBranch': 'Branch'})

# Keep only the columns needed for ordering and stamp the run date.
prediction_results = prediction_results[
    ['InvoiceDate', 'Product', 'Branch', 'TopCategory', 'Predicted', 'Upper_80']
].copy()
prediction_results['Run_Date'] = pd.to_datetime(today_date)

# Right join: every prediction is kept, with the stock row dated today if any.
order_qty = socks_and_masters.rename(columns={'InvoiceDate': 'Run_Date'}).merge(
    prediction_results,
    on=['Product', 'Branch', 'Run_Date'],
    how='right',
)

# Negative stock is floored at zero; missing stock counts as zero.
order_qty['Inventory'] = order_qty['Inventory'].clip(lower=0).fillna(0)

shelf_life = 1
# Round conditionally based on 'IsDecimalAllow'
order_qty['IsDecimalAllow'] = order_qty['IsDecimalAllow'].fillna(0)

order_qty['Predicted'] = _round_by_unit(order_qty['Predicted'])
order_qty['Upper_80'] = _round_by_unit(order_qty['Upper_80'])

# Order quantity = demand estimate minus stock on hand, never negative.
order_qty['order_qty_upper80'] = _round_by_unit(
    (order_qty['Upper_80'] - order_qty['Inventory']).clip(lower=0)
)
order_qty['order_qty_predicted'] = _round_by_unit(
    ((order_qty['Predicted'] * shelf_life) - order_qty['Inventory']).clip(lower=0)
)

_output_columns = [
    'id', 'Product', 'Branch', 'CompanyBranchId', 'InvoiceDate', 'TopCategory',
    'Inventory', 'Predicted', 'Upper_80', 'order_qty_upper80',
    'order_qty_predicted', 'Run_Date',
]
order_qty = order_qty[_output_columns]
order_qty.columns = order_qty.columns.str.lower()

logger.info("Order Quantity Created")

#%%

# Pushing results to SQL: a direct pyodbc connection built from the same ODBC
# connection string as the SQLAlchemy engine.
conn = pyodbc.connect(conn_str)
conn.autocommit = True  # each statement is committed immediately; bulk-insert speed comes from fast_executemany in write_sql_pyodbc

def write_sql_pyodbc(df, table, conn):
    """Bulk-insert the rows of ``df`` into ``table`` through a pyodbc connection.

    The target column list is read from the table itself (``SELECT TOP 0``),
    and ``df`` is reordered to match it, so ``df`` must contain every column
    of the table. Nothing is committed here; that is left to the
    connection's autocommit setting or to the caller.

    Args:
        df: Rows to insert. An empty frame is a no-op.
        table: Target table name. It is interpolated into the SQL text, so it
            must come from trusted code, never from user input.
        conn: Open pyodbc connection.

    Raises:
        KeyError: if ``df`` lacks a column that exists in the table.
    """
    if df.empty:
        # pyodbc's executemany raises on an empty parameter sequence.
        return

    cursor = conn.cursor()
    try:
        cursor.fast_executemany = True

        # Get columns from table
        cursor.execute(f"SELECT TOP 0 * FROM {table}")
        sql_cols = [desc[0] for desc in cursor.description]

        # Ensure column order matches
        df = df[sql_cols]

        placeholders = ", ".join(["?"] * len(sql_cols))
        columns_str = ", ".join(f"[{col}]" for col in sql_cols)
        sql = f"INSERT INTO {table} ({columns_str}) VALUES ({placeholders})"

        cursor.executemany(sql, df.values.tolist())
    finally:
        # Release the cursor even when the lookup or the insert fails.
        cursor.close()

def exec_sql(query, params=None):
    """Execute one SQL statement on the module-level pyodbc connection and commit.

    Args:
        query: SQL text. Use ``?`` placeholders together with ``params``
            rather than formatting values into the string.
        params: Optional sequence of values bound to the placeholders.

    The cursor is always closed; the connection itself stays open.
    """
    cursor = conn.cursor()
    try:
        if params is None:
            cursor.execute(query)
        else:
            cursor.execute(query, params)
        # Commit the changes
        conn.commit()
    finally:
        cursor.close()

# Every value is sent as text, with the textual forms of missing values
# replaced by None so they are inserted as NULL.
for col in list(order_qty.columns):
    order_qty[col] = order_qty[col].astype(str).replace({"": None, "nan": None, "NaT": None, "None": None})

logger.info(f"Checking existing records in results table for run_date={today_date}")

# today_date comes from date.today(), not from external input.
check_query = f"""
SELECT COUNT(1)
FROM results
WHERE run_date = '{today_date}'
"""

delete_query = f"""
DELETE FROM results 
WHERE run_date = '{today_date}'
"""

existing_count = pd.read_sql(check_query, engine).iloc[0, 0]

if existing_count > 0:
    logger.warning(
        f"Existing records found for run_date={today_date} | rows={existing_count}. Deleting before insert."
    )
else:
    logger.info("No existing records found for today — fresh insert")

rows_to_insert = len(order_qty)
logger.info(f"Uploading predictions to results table | rows={rows_to_insert}")

# Delete and insert run in ONE transaction: if the insert fails, the rollback
# restores the rows that were just deleted instead of leaving today's results
# empty. Autocommit is switched off only for the duration of the swap.
conn.autocommit = False
try:
    if existing_count > 0:
        _delete_cursor = conn.cursor()
        try:
            _delete_cursor.execute(delete_query)
        finally:
            _delete_cursor.close()
        logger.info("Old records deleted successfully")

    write_sql_pyodbc(order_qty, "results", conn)
    conn.commit()
except Exception:
    conn.rollback()
    logger.critical("Upload to results table failed — transaction rolled back", exc_info=True)
    raise
finally:
    conn.autocommit = True

logger.info("Predictions successfully uploaded to results table")

# Read today's rows back from the table for the Excel export.
read_query = f"""SELECT [id]
                      ,[product]
                      ,[branch]
                      ,[invoicedate]
                      ,[topcategory]
                      ,[inventory]
                      ,[predicted]
                      ,[order_qty_predicted]
                      ,[upper_80]
                      ,[order_qty_upper80]
                  FROM [united_king].[dbo].[results]
                  WHERE [run_date] = '{today_date}'"""

result_df = pd.read_sql(read_query, engine)
# Flag rows that had positive stock on the run date.
result_df['IsInvReceived'] = np.where(result_df['inventory'] > 0, 'Yes', 'No')

# Written relative to the current working directory (logged below).
result_df.to_excel("united_king_predictions.xlsx", index=False)

logger.info("Sent results to Excel file: united_king_predictions.xlsx, Saved at %s", os.getcwd())