In [1]:
import numpy as np
import pandas as pd
import math
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from path import Path
from collections import defaultdict
import os
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from dotenv import load_dotenv
import json
import warnings
warnings.filterwarnings('ignore')

In [2]:
# Load environment variables (from the process environment and/or a .env file)
load_dotenv()
FOLDER_PATH = os.getenv('FOLDER_PATH')
FILE_NAME = os.getenv('FILE_NAME')
# Fail fast: os.getenv returns None for unset variables, which would otherwise
# only surface later as an opaque TypeError inside os.path.join.
_missing_env = [name for name, value in (('FOLDER_PATH', FOLDER_PATH), ('FILE_NAME', FILE_NAME)) if value is None]
if _missing_env:
    raise EnvironmentError(f"Missing required environment variable(s): {', '.join(_missing_env)}")
# Select the interpolation method (forwarded to pandas.Series.interpolate).
# Other candidates listed by the author: 'akima', 'pchip', 'quadratic'
interpolation_method = 'linear'

In [3]:
def is_date_column(col):
    """Return True when a column label can be parsed as a date.

    Parameters
    ----------
    col : object
        A single column label (typically a string or a datetime object).

    Returns
    -------
    bool
        True if ``pd.to_datetime(col)`` succeeds and yields an actual
        timestamp; False if parsing raises or yields a missing value
        (``None`` / ``NaT``), which could not be turned into a day of year.
    """
    try:
        parsed = pd.to_datetime(col)  # Try converting the column name to a date
    except Exception:
        # Any parsing failure means "not a date"; unlike a bare ``except`` this
        # lets KeyboardInterrupt / SystemExit propagate.
        return False
    # A missing label parses "successfully" to None/NaT but is not a usable date.
    return parsed is not None and parsed is not pd.NaT
    
def load_filtered_csv(file_name):
    """Load one sheet of an Excel workbook, filter it, and split it by treatment.

    Despite the name, the file is read with ``pandas.read_excel`` (openpyxl
    engine). The settings for ``file_name`` are read from
    ``config_SingleDataFile.json``: ``sheet_name``, ``cultivar`` (a mapping of
    column name -> required value) and ``treatments`` (a mapping whose first
    item is ``column name -> list of treatment labels``).

    Steps:
      1. keep rows matching *every* cultivar filter (string columns are
         stripped before comparison),
      2. drop rows containing any NaN and reset the index,
      3. rename every column whose label parses as a date to its day of year,
      4. split the rows into one dataframe per treatment label.

    Side effect: sets the module-level ``treatments`` to the list of treatment
    labels from the config (other cells read it).

    Parameters
    ----------
    file_name : str
        Workbook file name, relative to ``FOLDER_PATH``; also the key into the
        JSON config.

    Returns
    -------
    dict or None
        ``{treatment label: dataframe}``, or ``None`` (after printing a
        message) when the workbook does not exist.
    """
    global treatments

    # Load json file
    with open("config_SingleDataFile.json", "r") as file:
        config = json.load(file)

    file_path = os.path.join(FOLDER_PATH, file_name)  # Join folder path with file name
    if not os.path.exists(file_path):
        print(f"File {file_path} not found.")
        return None

    file_config = config[file_name]

    # Load the Excel file; the context manager closes the workbook handle
    with pd.ExcelFile(file_path, engine="openpyxl") as xls:
        df = pd.read_excel(xls, sheet_name=file_config["sheet_name"])

    # Filter based on cultivar: apply the filters cumulatively so that every
    # configured column/value pair must match (an empty mapping keeps all rows)
    filtered_df = df
    for column, wanted in file_config["cultivar"].items():
        filtered_df = filtered_df.loc[filtered_df[column].str.strip() == wanted]

    # Remove rows with any NaN values, then reset index
    filtered_df = filtered_df.dropna().reset_index(drop=True)

    # Identify date columns & replace their labels with the day of the year
    day_of_year_labels = {
        col: pd.to_datetime(col).dayofyear
        for col in filtered_df.columns
        if is_date_column(col)
    }
    filtered_df = filtered_df.rename(columns=day_of_year_labels)

    # Filter based on treatments (only the first item of the mapping is used)
    treatments_column_name, treatments = list(file_config["treatments"].items())[0]
    treatments_df = {
        treatment: filtered_df.loc[filtered_df[treatments_column_name] == treatment]
        for treatment in treatments
    }

    return treatments_df

def interpolate_over_sum(treatments_df, method=None):
    """Average bud counts per treatment, interpolate to daily values and sum.

    For every treatment the per-date mean of the cumulative counts and the
    increment between consecutive sampling dates ("new buds"; the first date
    keeps its cumulative value) are computed, reindexed to every day between
    the first and last sampling day, and interpolated. The interpolated
    series of all treatments are then added up.

    Side effect: sets the module-level ``full_date_range`` (array of
    consecutive days of year), which other cells read.

    Parameters
    ----------
    treatments_df : dict
        ``{treatment label: dataframe}``. Date columns are the integer
        (day-of-year) column labels of the first dataframe; all dataframes
        are expected to share them.
    method : str, optional
        Interpolation method forwarded to ``pandas.Series.interpolate``.
        Defaults to the module-level ``interpolation_method``.

    Returns
    -------
    tuple
        ``(cumulative_interpol, new_buds_interpol, all_trtmnts,
        cumulative_std, new_buds_std)`` where the first two are
        ``{method: {treatment: Series}}``, ``all_trtmnts`` maps
        ``'cumulative'`` / ``'new'`` to the summed Series, and the last two
        are ``{treatment: Series}`` of standard deviations that are NaN on
        days without a measurement (they are not interpolated).

    Raises
    ------
    ValueError
        If ``treatments_df`` is empty or has no usable day-of-year columns.
    """
    global full_date_range

    if method is None:
        method = interpolation_method
    if not treatments_df:
        raise ValueError("treatments_df is empty; nothing to interpolate.")

    # Identify columns that are dates (columns are the same for every key, so
    # the first dataframe is used). bool is excluded because it subclasses int.
    # NOTE(review): the upper bound is 356, not 366 - confirm this is intended.
    first_frame = next(iter(treatments_df.values()))
    date_stamps = [
        col for col in first_frame.columns
        if isinstance(col, (int, np.integer)) and not isinstance(col, bool) and 1 < int(col) < 356
    ]
    if not date_stamps:
        raise ValueError("No day-of-year columns (integer labels between 2 and 355) found.")

    # Complete range of days from the first to the last sampling day; it only
    # depends on the date columns, so it is computed once for all treatments.
    day_range = np.arange(start=min(date_stamps), stop=max(date_stamps) + 1, step=1)
    full_date_range = day_range

    cumulative_buds = {}
    new_buds = {}
    cumulative_std = {}
    new_buds_std = {}

    for k, frame in treatments_df.items():
        per_unit = frame[date_stamps]

        # Mean and (population, ddof=0) std of the cumulative counts per date
        cumulative_mean = per_unit.mean()
        spread_cumulative = per_unit.std(ddof=0)

        # Difference between consecutive means, keeping the first value unchanged
        increments_mean = cumulative_mean.diff()
        increments_mean.iloc[0] = cumulative_mean.iloc[0]

        # Row-wise increments between consecutive dates; prepending 0 keeps the
        # first date's raw counts. Sample std (ddof=1) across rows per date.
        counts = per_unit.to_numpy(dtype='float64')
        increments = np.diff(counts, axis=1, prepend=0.0)
        spread_new = pd.Series(np.std(increments, axis=0, ddof=1), index=date_stamps)

        # Reindex to include the days without a measurement (NaN there)
        cumulative_buds[k] = cumulative_mean.reindex(day_range)
        new_buds[k] = increments_mean.reindex(day_range)
        cumulative_std[k] = spread_cumulative.reindex(day_range)
        new_buds_std[k] = spread_new.reindex(day_range)

    # Interpolate missing values
    cumulative_interpol = {method: {k: series.interpolate(method=method) for k, series in cumulative_buds.items()}}
    new_buds_interpol = {method: {k: series.interpolate(method=method) for k, series in new_buds.items()}}

    # Sum over the treatments that were actually passed in
    all_trtmnts = defaultdict(lambda: pd.Series(0, index=day_range, dtype='float64'))
    for k in treatments_df:
        all_trtmnts['cumulative'] += cumulative_interpol[method][k]
        all_trtmnts['new'] += new_buds_interpol[method][k]

    return cumulative_interpol, new_buds_interpol, all_trtmnts, cumulative_std, new_buds_std

def poly_fit(df, degree):
    """Fit a least-squares polynomial to a series and return the fitted values.

    The x values are taken from the index of ``df`` (the days of year), so the
    function does not depend on any module-level state. Before fitting, x is
    centred and scaled: raw day-of-year values raised to a high power make the
    least-squares problem badly conditioned, while the shifted/scaled variable
    spans the same polynomial space and therefore describes the same model.

    Parameters
    ----------
    df : pandas.Series
        Values to fit, indexed by numeric x positions.
    degree : int
        Degree of the fitted polynomial.

    Returns
    -------
    numpy.ndarray
        The polynomial evaluated at every index position of ``df``.
    """
    x = np.asarray(df.index, dtype='float64')
    y = np.asarray(df.values, dtype='float64')  # Extract values

    spread = x.std()
    # Guard against a zero spread (single point / constant index)
    x_scaled = (x - x.mean()) / (spread if spread > 0 else 1.0)

    # Fit a polynomial of degree (n) and evaluate it on the same positions
    coefficients = np.polyfit(x_scaled, y, degree)
    return np.polyval(coefficients, x_scaled)

def calculate_fit_stats(original_values, fitted_values):
    """Summarise how closely ``fitted_values`` follow ``original_values``.

    Parameters
    ----------
    original_values : array-like with ``.max()`` / ``.min()``
        Observed values.
    fitted_values : array-like
        Values produced by the fit, same length as ``original_values``.

    Returns
    -------
    dict
        Keys ``"MSE"``, ``"RMSE"``, ``"MAE"``, ``"R2"`` and ``"NRMSE"``
        (RMSE divided by the range of the original values).
    """
    # Mean squared error and its square root (same unit as the data)
    squared_error = mean_squared_error(original_values, fitted_values)
    root_error = np.sqrt(squared_error)
    # Mean absolute error
    absolute_error = mean_absolute_error(original_values, fitted_values)
    # Coefficient of determination; closer to 1 is better
    determination = r2_score(original_values, fitted_values)
    # Normalise the RMSE by the span of the observed values
    value_span = original_values.max() - original_values.min()
    normalised_root_error = root_error / value_span

    stats = {}
    stats["MSE"] = squared_error
    stats["RMSE"] = root_error
    stats["MAE"] = absolute_error
    stats["R2"] = determination
    stats["NRMSE"] = normalised_root_error
    return stats

class InteractivePlotGen:
    """Build the interactive (plotly) HTML reports for the bud-count analysis.

    Besides the data handed to the constructor, the plotting methods read the
    module-level names ``full_date_range`` (x values), ``treatments`` (labels
    to draw) and ``interpolation_method`` (key into the interpolated dicts),
    so they must be defined before a plot method is called.
    """

    def __init__(self, cumulative, new, all_trtmnts, fitted_poly, stats_cumulative, stats_new, cumulative_std, new_buds_std):
        """Store the inputs of the plots.

        Parameters
        ----------
        cumulative, new : dict
            ``{method: {treatment: Series}}`` of interpolated cumulative / new buds.
        all_trtmnts : mapping
            ``'cumulative'`` / ``'new'`` -> Series summed over all treatments.
        fitted_poly : mapping
            ``'cumulative'`` / ``'new'`` -> fitted values (indexable sequence).
        stats_cumulative, stats_new : dict
            Fit statistics; the keys ``'R2'`` and ``'NRMSE'`` are used.
        cumulative_std, new_buds_std : dict
            ``{treatment: Series}`` used as error bars.
        """
        self.cumulative = cumulative
        self.new = new
        self.all_trtmnts = all_trtmnts
        self.fitted_poly = fitted_poly
        self.stats_cumulative = stats_cumulative
        self.stats_new = stats_new
        self.cumulative_std = cumulative_std
        self.new_buds_std = new_buds_std

    def _BB_plot(self, output_path="bud_num_interactive_plot.html"):
        """Plot cumulative and new buds over time and save the figure as HTML.

        Each of the two stacked subplots shows one line (with error bars) per
        treatment, a bar series for the sum over all treatments and the
        polynomial fitted to that sum.

        Parameters
        ----------
        output_path : str, optional
            Destination of the HTML file (defaults to the original file name).
        """
        # Create 2 subplots for cumulative and new buds over time
        fig = make_subplots(rows=2, cols=1, subplot_titles=[f"Cumulative Buds [R2: {self.stats_cumulative['R2']:.2f}]", f"Daily New Buds [R2: {self.stats_new['R2']:.2f}]"])

        # (row, per-treatment series, per-treatment std, key of the totals, bar name)
        panels = (
            (1, self.cumulative, self.cumulative_std, 'cumulative', 'All Trtmnts (Cumulative)'),
            (2, self.new, self.new_buds_std, 'new', 'All Trtmnts (New)'),
        )
        for row_num, per_treatment, spread, key, total_name in panels:
            # One line with error bars per treatment
            for treatment in treatments:
                fig.add_trace(
                    go.Scatter(x=full_date_range, y=per_treatment[interpolation_method][treatment].values, mode='lines+markers', error_y=dict(type='data', array=spread[treatment], visible=True), name=treatment),
                    row=row_num, col=1)
            # Sum over all treatments
            fig.add_trace(
                go.Bar(x=full_date_range, y=self.all_trtmnts[key].values, name=total_name),
                row=row_num, col=1)
            # Polynomial fitted over all treatments
            fig.add_trace(
                go.Scatter(x=full_date_range, y=self.fitted_poly[key], mode='lines', line=dict(color="black", dash="dash"), name=f'{total_name}--fitted'),
                row=row_num, col=1)

        # Set y-axis labels for each subplot
        for row_num in (1, 2):
            fig.update_yaxes(title_text="Mean bud number", row=row_num, col=1)
        # Update layout
        fig.update_layout(
            title="Cumulative & Daily New Buds Over Time",
            showlegend=True,
            height=700, width=1200)

        fig.update_xaxes(
            tickmode="array",
            tickvals=full_date_range,  # Use all available dates as ticks
            tickangle=45)  # Rotate for better visibility

        # Save as interactive HTML
        fig.write_html(output_path)

    def _origin_predict_plot(self, output_path="original vs predicted.html"):
        """Plot original against fitted values and save the figure as HTML.

        Each subplot shows the fitted values (x) against the summed original
        values (y), a 1:1 reference line and a box with R2 and NRMSE.

        Parameters
        ----------
        output_path : str, optional
            Destination of the HTML file (defaults to the original file name).
        """
        # Create 2 subplots for cumulative and new buds (original vs predicted)
        fig = make_subplots(rows=2, cols=1, subplot_titles=["Cumulative Buds", "Daily New Buds"])

        # (row, key of the data, name of the 1:1 line)
        scatter_panels = (
            (1, 'cumulative', 'Cumulative 1:1'),
            (2, 'new', 'New 1:1'),
        )
        for row_num, key, reference_name in scatter_panels:
            original = self.all_trtmnts[key].values
            fig.add_trace(go.Scatter(
                x=self.fitted_poly[key], y=original, mode='markers',
                name='Original vs Fitted', marker=dict(color='blue')), row=row_num, col=1)
            # NOTE(review): the colour is given through ``marker`` although the
            # mode is 'lines' - confirm the reference line renders red as intended.
            fig.add_trace(go.Scatter(
                x=original, y=original, mode='lines',
                name=reference_name, marker=dict(color='red')), row=row_num, col=1)

        # Add stats as an annotation to each subplot: (key, stats, x-axis ref, y-axis ref)
        annotation_panels = (
            ('cumulative', self.stats_cumulative, "x1", "y1"),
            ('new', self.stats_new, "x2", "y2"),
        )
        for key, stats, x_axis, y_axis in annotation_panels:
            fitted = self.fitted_poly[key]
            fig.add_annotation(
                text=f"R2 = {stats['R2']:.2f}<br>NRMSE = {stats['NRMSE']:.2f}",
                x=fitted[len(fitted) // 3],  # Fitted value one third of the way through the series
                y=max(self.all_trtmnts[key].values),  # Position near the top of y-axis
                showarrow=False,
                xref=x_axis,
                yref=y_axis,
                font=dict(size=10, color="black"),
                align="left",
                bordercolor="black",
                borderwidth=1,
                bgcolor="white")

        # Set x-axis & y-axis labels for each subplot
        for row_num in (1, 2):
            fig.update_yaxes(title_text="Original", title_font=dict(size=10), row=row_num, col=1)
        for row_num in (1, 2):
            fig.update_xaxes(title_text="Predicted", title_font=dict(size=10), row=row_num, col=1)
        # Customize layout
        fig.update_layout(title="Original vs. Fitted Data", template="plotly_white")

        # Save as interactive HTML file
        fig.write_html(output_path)

In [4]:
# Degrees of the polynomials fitted to the summed series
CUMULATIVE_POLY_DEGREE = 2
NEW_BUDS_POLY_DEGREE = 6

treatments_df = load_filtered_csv(FILE_NAME)
if treatments_df is None:
    # load_filtered_csv prints a message and returns None when the workbook is
    # missing; stop here instead of failing later with an unrelated error.
    raise FileNotFoundError(f"Could not load {FILE_NAME!r}; see the message printed above.")

cumulative_interpol, new_buds_interpol, all_trtmnts, cumulative_std, new_buds_std = interpolate_over_sum(treatments_df)

# Fit a polynomial to the sum over all treatments
fitted_poly = {
    'cumulative': poly_fit(all_trtmnts['cumulative'], CUMULATIVE_POLY_DEGREE),
    'new': poly_fit(all_trtmnts['new'], NEW_BUDS_POLY_DEGREE),
}

# Calculate stats for all_trtmnts & fitted_poly
stats_cumulative = calculate_fit_stats(all_trtmnts['cumulative'].values, fitted_poly['cumulative'])
stats_new = calculate_fit_stats(all_trtmnts['new'].values, fitted_poly['new'])

# Write both interactive HTML reports
plotter = InteractivePlotGen(cumulative_interpol, new_buds_interpol, all_trtmnts, fitted_poly, stats_cumulative, stats_new, cumulative_std, new_buds_std)
plotter._BB_plot()
plotter._origin_predict_plot()