In [None]:
import json
import warnings
from glob import glob
from importlib import reload

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from cmdstanpy import stanfit

from nteprsm import utils
from nteprsm.constants import MONTH_ABBR, MONTH_BINS
from settings import ROOT_DIR
import utils as notebook_utils

# Reload the local helper modules first, so any edits made to them since the
# kernel started are picked up *before* they are used below.
reload(notebook_utils)
reload(utils)
# use the customized plotly template for every figure in this notebook
notebook_utils.set_custom_template()

# NOTE(review): this silences *all* warnings for the rest of the session;
# consider narrowing it to the specific noisy categories.
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# load model configuration
config_file = ROOT_DIR/"config/nteprsm_nj2kbg07.yml"
config = utils.load_config(config_file)
# Load the posterior samples from the CmdStan output CSVs (presumably one per
# chain). glob() returns files in arbitrary file-system order, so sort them to
# get a deterministic chain order across machines and re-runs.
files = sorted(glob(str(ROOT_DIR/config["sampling"]["output_dir"]
                        /"nteprsm_turf_annual_seasonality-*.csv")))
if not files:
    raise FileNotFoundError(
        f"No posterior sample CSVs found in {ROOT_DIR/config['sampling']['output_dir']}"
    )
print(files)
fit = stanfit.from_csv(files)

['/Users/henryqu/Documents/GitHub/nteprsm/data/model_output/nj2/nteprsm_turf_annual_seasonality-20240421_3.csv', '/Users/henryqu/Documents/GitHub/nteprsm/data/model_output/nj2/nteprsm_turf_annual_seasonality-20240421_2.csv', '/Users/henryqu/Documents/GitHub/nteprsm/data/model_output/nj2/nteprsm_turf_annual_seasonality-20240421_1.csv', '/Users/henryqu/Documents/GitHub/nteprsm/data/model_output/nj2/nteprsm_turf_annual_seasonality-20240421_4.csv']


In [3]:
# Load the raw quality ratings, clean them, and build the data dictionary that
# is handed to the Stan model (extra entries come from the config file).
data_path = ROOT_DIR / config["data_path"]
datahandler = utils.DataHandler(filepath=data_path)
datahandler.load_data()
datahandler.preprocess_data()
stan_extra = config["stan_additional_data"]
datahandler.generate_stan_data(**stan_extra)

2025-03-08 23:30:23,065 - NtepRsm - INFO - Loading data from /Users/henryqu/Documents/GitHub/nteprsm/data/raw/quality_nj2.csv...
2025-03-08 23:30:23,079 - NtepRsm - INFO - Start preprocessing data...
2025-03-08 23:30:23,103 - NtepRsm - INFO - Data preprocessing completed.
2025-03-08 23:30:23,103 - NtepRsm - INFO - Generating data dictionary for the model...



# Background Information
The turfgrass plots are established in multiple locations as shown below. Each plot is assigned a coordinate based on its row and column number within the grid design of the trial. 
![Turfgrass Trial Locations](../reports/figures/coolseason_turfgrass_trial_example.jpg)


In [4]:
# Keep only the columns used downstream and persist them for reuse outside
# this notebook.
model_data = datahandler.model_data
raw_columns = [
    'entry_name', 'entry_name_code', 'date', 'adj_day_of_year',
    'quality', 'row', 'col', 'plt_id', 'test_loc',
]
raw_data = model_data[raw_columns]
raw_data.to_csv(ROOT_DIR / "data/processed/raw_data.csv", index=False)
raw_data.head()

Unnamed: 0,entry_name,entry_name_code,date,adj_day_of_year,quality,row,col,plt_id,test_loc
0,NAI-14-132,62,2018-04-18,108,4,17,1,241,"Adelphia, NJ"
1,NAI-14-132,62,2018-05-10,130,4,17,1,241,"Adelphia, NJ"
2,NAI-14-132,62,2018-06-07,158,5,17,1,241,"Adelphia, NJ"
3,NAI-14-132,62,2018-07-17,198,5,17,1,241,"Adelphia, NJ"
4,NAI-14-132,62,2018-08-15,227,6,17,1,241,"Adelphia, NJ"


# 1. Visualization of seasonality for a single location
The trial spans multiple years, and we developed a model to extract the seasonality of the turfgrass entries for an average year. To consolidate data from all five years into this average year, we created the `adj_day_of_year`. This adjusted day of the year aligns dates after February 28 in a leap year with the same “day number” they would have in a non-leap year, effectively ignoring the shift caused by February 29.

In [5]:
# Wrap the preprocessed data and the posterior draws in the project's analysis
# helper; later cells use it to summarise predictions (means, quantiles and
# monthly averages).
sampleanalysis = utils.PosteriorSampleAnalysis(datahandler, fit)

2025-03-08 23:30:32,996 - NtepRsm - DEBUG - Logging is already configured.


In [6]:
# Mean of the predicted seasonal curve: rows are adj_day_of_year grid points,
# columns are entry codes.
pred_means = sampleanalysis.get_predicted_statistics(np.mean)
pred_means_csv = ROOT_DIR / "data/processed/nj2_seasonality_pred_means.csv"
pred_means.to_csv(pred_means_csv)
pred_means.head()

entry_name_code,0,1,2,3,4,5,6,7,8,9,...,79,80,81,82,83,84,85,86,87,88
adj_day_of_year,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
3.65,0.004752,0.435733,1.245416,-1.126459,-0.052662,-0.168781,0.943314,0.061763,-1.391306,1.615321,...,-0.687067,-0.9471,-0.102299,-1.213266,0.779621,1.31472,0.785175,0.036597,0.277514,-0.674089
7.3,-0.008615,0.44902,1.247659,-1.149227,-0.030852,-0.16277,0.953147,0.052261,-1.426613,1.622098,...,-0.698909,-0.971476,-0.125813,-1.233578,0.77535,1.287427,0.79989,0.03412,0.266112,-0.680887
10.95,-0.020547,0.461056,1.249783,-1.170166,-0.013886,-0.157877,0.959976,0.044439,-1.457709,1.628624,...,-0.710045,-0.993231,-0.145589,-1.252871,0.770651,1.263383,0.812494,0.030597,0.255264,-0.688849
14.6,-0.031292,0.472257,1.251264,-1.189398,-0.000999,-0.153717,0.963764,0.038095,-1.48499,1.634948,...,-0.720815,-1.012745,-0.161601,-1.271488,0.765383,1.241786,0.82293,0.025774,0.24438,-0.698495
18.25,-0.041098,0.483137,1.251648,-1.207335,0.008651,-0.15002,0.964498,0.032898,-1.509093,1.641162,...,-0.731669,-1.030501,-0.17413,-1.289939,0.759367,1.221761,0.830976,0.019356,0.23301,-0.710261


In [7]:
# 2.5% quantile of the predictions (presumably across posterior draws):
# the lower edge of the 95% band drawn in the seasonality plot.
pred_lb = sampleanalysis.get_predicted_statistics(np.quantile, 0.025)
pred_lb_csv = ROOT_DIR / "data/processed/nj2_seasonality_pred_lb.csv"
pred_lb.to_csv(pred_lb_csv)
pred_lb.head()

entry_name_code,0,1,2,3,4,5,6,7,8,9,...,79,80,81,82,83,84,85,86,87,88
adj_day_of_year,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
3.65,-1.313905,-0.834954,-0.088014,-2.412748,-1.335411,-1.439639,-0.358867,-1.217672,-2.684949,0.362007,...,-1.922721,-2.266421,-1.400069,-2.487029,-0.516962,-0.014189,-0.502409,-1.267193,-1.084493,-1.973391
7.3,-1.326605,-0.823256,-0.104954,-2.440139,-1.310276,-1.452919,-0.351496,-1.232068,-2.725655,0.36828,...,-1.957875,-2.294716,-1.439304,-2.535252,-0.516879,-0.025779,-0.487116,-1.276325,-1.118898,-2.006571
10.95,-1.319419,-0.818988,-0.102944,-2.464031,-1.287039,-1.451434,-0.339654,-1.250245,-2.779333,0.342527,...,-1.986119,-2.312465,-1.463,-2.562204,-0.525286,-0.057508,-0.478278,-1.288612,-1.125623,-2.023661
14.6,-1.342322,-0.807726,-0.075708,-2.503276,-1.27142,-1.482225,-0.330402,-1.254099,-2.820148,0.333759,...,-2.00662,-2.316748,-1.479633,-2.599666,-0.539269,-0.08281,-0.500981,-1.296667,-1.096522,-2.035342
18.25,-1.358149,-0.807254,-0.07483,-2.524197,-1.268032,-1.490416,-0.336395,-1.268345,-2.858038,0.352816,...,-2.042524,-2.326162,-1.497919,-2.611091,-0.53788,-0.107109,-0.498908,-1.309955,-1.092455,-2.05863


In [8]:
# 97.5% quantile of the predictions (presumably across posterior draws):
# the upper edge of the 95% band drawn in the seasonality plot.
pred_ub = sampleanalysis.get_predicted_statistics(np.quantile, 0.975)
pred_ub_csv = ROOT_DIR / "data/processed/nj2_seasonality_pred_ub.csv"
pred_ub.to_csv(pred_ub_csv)
pred_ub.head()

entry_name_code,0,1,2,3,4,5,6,7,8,9,...,79,80,81,82,83,84,85,86,87,88
adj_day_of_year,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
3.65,1.31762,1.724149,2.574621,0.21411,1.206651,1.124508,2.258245,1.342083,-0.071458,2.924854,...,0.605358,0.342024,1.186018,0.058326,2.083911,2.650299,2.056775,1.38812,1.602967,0.619068
7.3,1.310659,1.739149,2.579097,0.192054,1.233242,1.140906,2.270542,1.341087,-0.113147,2.942757,...,0.607434,0.355813,1.192842,0.058283,2.089649,2.617532,2.111338,1.381388,1.607446,0.624507
10.95,1.298543,1.743478,2.588833,0.168232,1.258503,1.130407,2.262249,1.322721,-0.147317,2.93718,...,0.601721,0.336813,1.170032,0.027526,2.086585,2.585351,2.137333,1.387452,1.610553,0.618699
14.6,1.271547,1.767762,2.587456,0.151336,1.262036,1.13959,2.246707,1.315772,-0.168027,2.942371,...,0.57993,0.316114,1.15959,0.011919,2.074026,2.570919,2.133868,1.377685,1.614561,0.631236
18.25,1.253609,1.781415,2.588257,0.13581,1.279393,1.152689,2.244444,1.326755,-0.21883,2.94094,...,0.562959,0.312374,1.1403,0.000957,2.064594,2.545156,2.145142,1.345461,1.588522,0.605495


In [None]:
# Fitted values, shown as individual data points on the seasonality plot.
# `fit.time_effect` is averaged over axis 0 (draws); after transposing, the
# columns are entry codes and the rows are labelled with entry 0's rating days.
# NOTE(review): assumes the model's time dimension lines up row-for-row with
# entry 0's records in `raw_data` -- confirm against the Stan data construction.
obs_days = raw_data.query("entry_name_code == 0")["adj_day_of_year"]
fit_data = pd.DataFrame(fit.time_effect.mean(axis=0)).T
fit_data.columns.name = 'entry_name_code'
fit_data.index = obs_days
# Keep one row per adjusted day of year. `drop_duplicates()` compares row values
# and ignores the index, so de-duplicate on the index labels explicitly.
fit_data = fit_data.loc[~fit_data.index.duplicated(keep="first")].sort_index()
fit_data.to_csv(ROOT_DIR/"data/processed/nj2_seasonality_fit_data.csv")
fit_data.head()

entry_name_code,0,1,2,3,4,5,6,7,8,9,...,79,80,81,82,83,84,85,86,87,88
adj_day_of_year,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
89,-0.392243,0.816396,0.745322,-1.821427,-0.059199,-0.636364,0.279271,-0.641194,-1.582109,1.720772,...,-1.144688,-1.321636,-0.246974,-1.52442,0.348218,0.618294,0.421941,-0.687665,-0.621524,-1.452892
107,-0.26141,0.618856,1.017378,-1.588908,0.073695,-0.948468,0.537034,-0.490224,-1.166286,1.536755,...,-0.799509,-1.108106,-0.262541,-0.892334,0.585954,0.997063,0.660259,-0.60259,-0.781708,-0.834093
108,-0.247863,0.602891,1.045762,-1.569332,0.096148,-0.952969,0.564045,-0.464797,-1.146966,1.525225,...,-0.775268,-1.099675,-0.256125,-0.859412,0.608526,1.025779,0.691945,-0.591688,-0.786589,-0.789647
112,-0.190312,0.538051,1.16878,-1.488994,0.204377,-0.953003,0.68015,-0.351532,-1.079859,1.480161,...,-0.678318,-1.074319,-0.223395,-0.740596,0.702632,1.139592,0.83359,-0.544248,-0.802594,-0.614175
118,-0.101641,0.442617,1.370806,-1.370603,0.414026,-0.900681,0.86922,-0.16497,-1.012916,1.418526,...,-0.541666,-1.063161,-0.161619,-0.610201,0.841263,1.293356,1.07273,-0.463382,-0.813675,-0.377043


In [10]:
# All the entries to plot: map entry codes back to names and persist the lookup.
# JSON object keys are always strings, so the integer codes are written as text.
code2name = datahandler.map_name2code('entry_name', 'entry_name_code', invert=True)
with open(ROOT_DIR/"data/processed/code2name.json", 'w', encoding='utf-8') as f:
    json.dump(code2name, f)
# Order the entries by their overall average for the whole year (mean of the
# monthly averages), highest first; the plot below adds traces in this order.
monthly_avg = sampleanalysis.get_predicted_monthly_means(pred_means)
annual_avg = monthly_avg.mean(axis=1)  # computed once, not once per sort key
entry_codes_to_plot = sorted(code2name, key=lambda code: annual_avg.loc[code], reverse=True)
entry_codes_to_plot[:5]

[9, 69, 26, 2, 12]

In [11]:
# Iterate over all the entries and plot, for each one: the fitted values as
# markers, the predicted mean curve as a line, and a shaded 95% credible band
# between the 2.5% and 97.5% quantiles.
fig = go.Figure()
# fixed qualitative palette so every trace of an entry shares one color
colors = px.colors.qualitative.Dark24
# the prediction grid (adj_day_of_year) is the same for every entry
x_pred = pred_means.index

for ix, code in enumerate(entry_codes_to_plot):
    entry_name = code2name[code]
    color = colors[ix % len(colors)]
    # fitted values; the only trace of the group that appears in the legend
    fig.add_trace(
        go.Scatter(
            x=fit_data.index.values,
            y=fit_data[code],
            mode="markers",
            marker=dict(size=5, color=color),
            name=entry_name,
            legendgroup=entry_name,
        )
    )
    # predicted mean curve for the whole year
    fig.add_trace(
        go.Scatter(
            x=x_pred,
            y=pred_means[code],
            mode="lines",
            line=dict(width=1.5, color=color),
            name=entry_name,
            legendgroup=entry_name,
            showlegend=False,
            hoverinfo="none",
        )
    )
    # Lower bound of the credible band. The color must be set explicitly;
    # otherwise Plotly takes the next template colorway color and the band's
    # lower edge does not match its entry.
    fig.add_trace(
        go.Scatter(
            x=x_pred,
            y=pred_lb[code],
            mode="lines",
            line=dict(width=0.5, color=color),
            name=entry_name,
            legendgroup=entry_name,
            showlegend=False,
            hoverinfo="none",
        )
    )
    # Upper bound. fill="tonexty" shades down to the previously added trace
    # (the lower bound just above), so these two traces must stay in this order.
    fig.add_trace(
        go.Scatter(
            x=x_pred,
            y=pred_ub[code],
            mode="lines",
            line=dict(width=0.5, color=color),
            name=entry_name,
            legendgroup=entry_name,
            showlegend=False,
            fill="tonexty",
            hoverinfo="none",
        )
    )

# Label the x axis with month abbreviations instead of day-of-year numbers.
# MONTH_BINS[:-1] drops the final bin edge so tick positions match MONTH_ABBR.
fig.update_layout(
    xaxis=dict(
        tickmode="array",
        tickvals=MONTH_BINS[:-1],
        ticktext=MONTH_ABBR,
        tickfont=dict(size=12),
    ),
    yaxis_title="Estimated Seasonality in Turf Quality",
    yaxis=dict(title_font=dict(size=12)),
    legend=dict(font=dict(size=12)),
)
fig.show()

# 2. Visualization of field variation in a single location
Once a location is selected, e.g., Adelphia, New Jersey (NJ2), we can also visualize the variation within that trial.

In [12]:
# Inspect the full ratings table (pandas truncates the display) before
# building the field-layout view from its plt_id/row/col columns.
raw_data

Unnamed: 0,entry_name,entry_name_code,date,adj_day_of_year,quality,row,col,plt_id,test_loc
0,NAI-14-132,62,2018-04-18,108,4,17,1,241,"Adelphia, NJ"
1,NAI-14-132,62,2018-05-10,130,4,17,1,241,"Adelphia, NJ"
2,NAI-14-132,62,2018-06-07,158,5,17,1,241,"Adelphia, NJ"
3,NAI-14-132,62,2018-07-17,198,5,17,1,241,"Adelphia, NJ"
4,NAI-14-132,62,2018-08-15,227,6,17,1,241,"Adelphia, NJ"
...,...,...,...,...,...,...,...,...,...
9607,After Midnight,16,2021-07-13,194,6,1,15,15,"Adelphia, NJ"
9608,After Midnight,16,2021-08-11,223,6,1,15,15,"Adelphia, NJ"
9609,After Midnight,16,2021-09-17,260,5,1,15,15,"Adelphia, NJ"
9610,After Midnight,16,2021-10-13,286,5,1,15,15,"Adelphia, NJ"


In [13]:
# Mean plot effect for every plot, joined with its field position (row/col)
# and the entry planted there, for the heatmap below.
# NOTE(review): assumes `fit.plot_effect` is ordered by ascending plt_id --
# confirm against the Stan data construction.
plot_effects = pd.DataFrame(fit.plot_effect.mean(axis=0), columns=['plot_effect'])
plot_effects['plt_id'] = raw_data.plt_id.sort_values().unique()
plot_layout = raw_data[['plt_id', 'row', 'col', 'entry_name']].drop_duplicates()
# validate="one_to_one" raises if a plt_id maps to more than one row/col/entry,
# instead of silently duplicating plot effects
plot_effects = plot_effects.merge(plot_layout, on='plt_id', how='left', validate='one_to_one')
plot_effects.to_csv(ROOT_DIR/"data/processed/nj2_seasonality_plot_effects.csv")
plot_effects.head()

Unnamed: 0,plot_effect,plt_id,row,col,entry_name
0,0.154937,1,1,1,A11-40
1,0.226737,2,1,2,A13-1
2,-0.058444,3,1,3,A99-2897
3,0.076843,4,1,4,Syrah (LTP-11-41)
4,0.002288,5,1,5,Blue Knight


In [14]:
# Heatmap of the mean plot effects laid out by field row and column. The
# entry planted in each plot is added to the hover text through customdata;
# both grids are pivoted the same way, so their cells line up.
effect_grid = plot_effects.pivot(index='row', columns='col', values='plot_effect')
entry_grid = plot_effects.pivot(index='row', columns='col', values='entry_name')

fig = px.imshow(effect_grid, aspect='equal', labels={'color': 'Plot Effect'})
hovertemplate = 'Row: %{y}<br>Column: %{x}<br>Plot Effect: %{z:0.3f}<br>Entry Name: %{customdata}'
fig.update_traces(customdata=entry_grid.values, hovertemplate=hovertemplate)
fig.show()

# 3. Radar plot for all seven locations at a pre-selected day of year
Note that this shows only the prediction at a single pre-selected `adj_day_of_year`. **TODO:** add a slider over all values of `adj_day_of_year` so that users can see how entries perform across all locations throughout the year.

In [15]:
# Per-location entry effects (one column per test location), indexed by entry
# name and ranked from highest to lowest at St. Paul, MN.
entry_effects = (
    pd.read_csv(ROOT_DIR/"data/processed/entry_effects_all_locations.csv")
    .set_index('ENTRY_NAME')
    .sort_values(by='St. Paul, MN', ascending=False)
)
entry_effects

Unnamed: 0_level_0,"St. Paul, MN","East Lansing, MI","Logan, UT","West Lafayette, IN","Adelphia, NJ","Stillwater, OK","Raleigh, NC"
ENTRY_NAME,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
Bombay (GO-22B23),0.734971,0.909240,0.438874,0.625752,0.962912,0.462082,0.575651
A10-280,0.607321,0.270730,-0.244933,0.321810,0.352431,0.697842,0.630751
Cloud (GO-2425),0.599630,0.492729,0.446384,0.622543,0.755832,0.428817,0.589300
Star (GO-2628),0.524738,0.528770,0.450669,0.687496,0.913741,0.481949,0.343271
Yellowstone (A12-7),0.510728,0.041446,0.086953,-0.004896,-0.579768,-0.035210,-0.148110
...,...,...,...,...,...,...,...
PPG-KB 1131,-0.496659,0.126869,0.205867,-0.053312,0.184798,0.273629,0.436910
Blue Knight,-0.598218,-1.395883,-0.284597,-1.083766,-1.505431,-0.164530,-1.172004
BAR PP 71213,-0.677184,0.952302,-0.484161,-0.128568,-0.051917,0.282594,-0.002125
DLFPS-340/3364,-0.879439,-1.131112,-0.520079,-0.564922,-0.732689,0.208770,-0.127599


In [16]:
# One closed polygon per entry: its effect at every location, with the first
# location repeated at the end so the outline closes.
fig = go.Figure()
locations = entry_effects.columns.tolist()
closed_theta = locations + locations[:1]

for entry_name, entry_eff in entry_effects.iterrows():
    effects = entry_eff.tolist()
    fig.add_trace(
        go.Scatterpolar(
            r=effects + effects[:1],
            theta=closed_theta,
            fill='toself',
            name=entry_name,
            hovertemplate="score:%{r:.2f}" + f'<br>Entry: {entry_name}<br><extra></extra>',
        )
    )

fig.update_layout(
    polar=dict(
        angularaxis=dict(showgrid=True),
        radialaxis=dict(showgrid=True, visible=True, dtick=1, range=[-2.5, 2]),
    ),
    showlegend=True,
)

fig.show()