In [1]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from plotly import graph_objects as go
from plotly import colors
from plotly.colors import sample_colorscale

In [2]:
import sys
import os

# Make the parent of the current working directory (the notebook's folder when Jupyter is
# launched normally) importable so the `src.*` modules below resolve.
# Guarded so re-running this cell does not keep appending duplicate entries to sys.path.
PROJECT_ROOT = os.path.abspath('..')
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [3]:
from src.table import CherryTable
import src.utils as utils
import src.preprocessing as pre

In [4]:
# Input data location, kept in one place so it is easy to point at another dataset.
DATA_DIR = '../data'

# NOTE(review): presumably separate_year splits the records by season into train/test sets and
# returns the categorical mapping used later by pre.decode — confirm in src.preprocessing.
train, test, mapping_dict = pre.separate_year(
    planting_meta_path=os.path.join(DATA_DIR, 'planting_meta.json'),
    weekly_summary_path=os.path.join(DATA_DIR, 'weekly_summary.csv'),
)

In [5]:
# Seed every RNG training may draw from so the epoch losses and downstream predictions are
# reproducible on Restart & Run All.
# NOTE(review): assumes utils.train_harvest_model is PyTorch-based (the test split exposes
# `.detach()` tensors) — confirm it uses no other, unseeded randomness source.
import random

import torch

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

model = utils.train_harvest_model(train)

Epoch [1/5], Loss: 832104.0675
Epoch [2/5], Loss: 800848.1241
Epoch [3/5], Loss: 797024.7587
Epoch [4/5], Loss: 776831.1709
Epoch [5/5], Loss: 723560.6772


In [6]:
# Score the held-out split with the freshly trained model.
predictions = utils.predict_harvest(
    model,
    test,
)

In [7]:
# NOTE(review): presumably maps the encoded test-split identifiers back to readable labels
# (e.g. Ranch / Class) via mapping_dict — confirm in src.preprocessing.
meta = pre.decode(test, mapping_dict)

# Ground-truth kilos as a NumPy array. `.cpu()` is a no-op for CPU tensors but is required
# before `.numpy()` when the tensor lives on a GPU.
actuals = test.Y_kilos.detach().cpu().numpy()

In [8]:
# Bundle metadata, the named prediction set(s) and the actuals into one table object.
table = CherryTable(meta, dict(predictions=predictions), actuals)

In [9]:
# Aggregate by ranch and class. Of the three results, only `d` (a dict of frames keyed by
# name) is used in the visible cells below.
a, d, ha_info = table.graph_ready(
    ranches=True,
    classes=True,
)

In [10]:
# The first frame produced by graph_ready; its key doubles as the plot title.
title = list(d)[0]

In [11]:
# Show the first frame from graph_ready (here 'predictions_summed'): one row per
# (Ranch, Class) and one column per week index.
d[title]

Unnamed: 0_level_0,Unnamed: 1_level_0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
Ranch,Class,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
OAP,BSUF,141.580505,42.129002,-79.00135,-33.188404,48.552689,-66.730507,118.187881,3.885859,302.009491,783.393372,2702.591553,5513.274902,8839.570312,12461.438477,14384.460938,14022.672852,13443.606445,11116.978516,8240.208984,4658.027832
OAP,CHE,159.264954,50.941032,-90.764786,-40.131687,56.782154,-84.152176,136.955414,6.0727,352.163025,894.482788,3112.199219,6369.236328,10195.237305,14443.368164,16715.9375,16229.042969,15597.234375,12891.987305,9549.679688,5376.845703
SFB,BSUF,58.401432,13.797932,-31.918432,-12.251407,20.116072,-20.774715,46.968842,0.8881,118.596527,319.073792,1085.886719,2203.264893,3541.764648,4942.098633,5673.289062,5579.597656,5321.032227,4405.35791,3269.047852,1862.490356
SFB,CHE,81.522392,43.660824,-60.48473,-20.98205,31.24098,-83.639877,72.536598,5.593054,209.132568,481.655273,1759.043823,3643.742676,5769.070801,8433.985352,9924.646484,9391.821289,9175.882812,7575.311523,5586.292969,3077.770264
SGB,BSUF,69.022247,27.110188,-45.003433,-15.061691,24.518091,-47.658936,57.550846,2.470041,157.771118,392.048645,1383.017456,2834.937988,4520.974609,6466.199219,7521.58252,7246.367188,7001.022949,5788.794434,4280.814453,2396.871826
SGB,CHE,72.102272,30.350292,-48.666756,-18.05534,26.563469,-55.473877,61.94725,3.234985,170.700287,415.112579,1480.420776,3045.930176,4846.888672,6976.883789,8143.666016,7804.391113,7564.089844,6251.09375,4618.970703,2573.273438
SJB,BSUF,41.333935,13.841717,-25.561249,-8.952772,14.741861,-24.199017,33.986065,1.48371,90.901726,231.670746,807.014832,1648.937744,2635.713135,3741.425049,4334.888672,4202.818359,4044.650635,3345.279297,2476.437256,1393.941528
SJB,CHE,63.563927,32.523899,-48.281998,-16.077612,24.387781,-64.45237,55.073399,4.447614,161.029633,373.225922,1360.451416,2815.167969,4455.970703,6509.052734,7655.687012,7251.572266,7082.646484,5849.21875,4312.89502,2378.250244
SMB,BSUF,95.801613,23.925289,-52.277828,-20.871101,33.071827,-35.307899,77.632919,2.084085,196.164612,525.432007,1790.964844,3637.037109,5845.034668,8169.209961,9385.200195,9217.574219,8797.381836,7281.737305,5402.238281,3074.310303
SMB,CHE,77.725571,28.724817,-49.175022,-18.86396,28.607771,-50.318039,66.018082,3.618845,177.603775,442.642456,1557.934204,3194.831299,5097.969727,7278.525879,8458.547852,8161.147461,7876.291504,6511.098633,4816.222656,2698.614014


In [12]:
# List the available frames. Judging by their names, these are summed, cumulative-sum and
# cumulative-proportion variants for both predictions and actuals.
d.keys()

dict_keys(['predictions_summed', 'predictions_summed_cumsum', 'predictions_summed_cumprop', 'actuals_summed', 'actuals_summed_cumsum', 'actuals_summed_cumprop'])

In [13]:
# Weeks as rows, (Ranch, Class) as a two-level column index — the layout the plot expects.
df = d[title].transpose()

In [14]:
# Actuals in the same weeks-by-(Ranch, Class) layout as df.
af = d['actuals_summed'].transpose()

In [15]:
# Predicted harvest per week (rows) for each (Ranch, Class) column.
df

Ranch,OAP,OAP,SFB,SFB,SGB,SGB,SJB,SJB,SMB,SMB,VAP,VAP
Class,BSUF,CHE,BSUF,CHE,BSUF,CHE,BSUF,CHE,BSUF,CHE,BSUF,CHE
0,141.580505,159.264954,58.401432,81.522392,69.022247,72.102272,41.333935,63.563927,95.801613,77.725571,414.309906,323.025757
1,42.129002,50.941032,13.797932,43.660824,27.110188,30.350292,13.841717,32.523899,23.925289,28.724817,72.713585,93.557816
2,-79.00135,-90.764786,-31.918432,-60.48473,-45.003433,-48.666756,-25.561249,-48.281998,-52.277828,-49.175022,-223.661865,-202.50943
3,-33.188404,-40.131687,-12.251407,-20.98205,-15.061691,-18.05534,-8.952772,-16.077612,-20.871101,-18.86396,-88.838608,-76.537018
4,48.552689,56.782154,20.116072,31.24098,24.518091,26.563469,14.741861,24.387781,33.071827,28.607771,140.03215,115.183113
5,-66.730507,-84.152176,-20.774715,-83.639877,-47.658936,-55.473877,-24.199017,-64.45237,-35.307899,-50.318039,-108.721626,-173.344757
6,118.187881,136.955414,46.968842,72.536598,57.550846,61.94725,33.986065,55.073399,77.632919,66.018082,320.684967,261.649994
7,3.885859,6.0727,0.8881,5.593054,2.470041,3.234985,1.48371,4.447614,2.084085,3.618845,3.150698,9.999574
8,302.009491,352.163025,118.596527,209.132568,157.771118,170.700287,90.901726,161.029633,196.164612,177.603775,802.081848,699.446716
9,783.393372,894.482788,319.073792,481.655273,392.048645,415.112579,231.670746,373.225922,525.432007,442.642456,2217.330078,1792.451416


In [16]:
# Actual harvest in the same layout as df. Note the zero-kg weeks early in the season.
af

Ranch,OAP,OAP,SFB,SFB,SGB,SGB,SJB,SJB,SMB,SMB,VAP,VAP
Class,BSUF,CHE,BSUF,CHE,BSUF,CHE,BSUF,CHE,BSUF,CHE,BSUF,CHE
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5,0.0,0.0,129.0,213.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
6,247.0,0.0,0.0,1333.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
7,0.0,0.0,852.0,883.0,33.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
8,0.0,796.0,2724.0,2004.0,116.0,36.0,1482.0,507.0,480.0,0.0,24.0,0.0
9,1280.0,2267.0,0.0,0.0,307.0,108.0,4236.0,742.0,0.0,1023.0,587.0,40.0


In [17]:
# Interactive predicted-vs-actual figure: one line (predicted) and one bar (actual) trace per
# (Ranch, Class) column of df, plus a dropdown that isolates one pair or shows everything.
fig = go.Figure()
x_values = df.index
n_series = len(df.columns)

# Evenly spaced Viridis samples, one per series. max(..., 1) avoids a ZeroDivisionError when
# there is only one series. Named `trace_colors` so the imported `plotly.colors` module is not
# shadowed by a list for the rest of the session.
trace_colors = sample_colorscale('Viridis', [i / max(n_series - 1, 1) for i in range(n_series)])
color_map = dict(zip(df.columns, trace_colors))

for col in df.columns:
    label = col[0] + ': ' + col[1]
    fig.add_trace(go.Scatter(x=x_values, y=df[col], name=label + ' Predicted', line=dict(color=color_map[col])))
    # Use the actuals' own index so each bar sits on its own week even if af's index differs from df's.
    fig.add_trace(go.Bar(x=af.index, y=af[col], name=label + ' Actual', marker_color=color_map[col], opacity=0.75))

# Traces were added as (Predicted, Actual) pairs, so series i owns traces 2i and 2i+1.
# Row i of this (n_series, 2 * n_series) mask is True exactly on that pair.
visibility_matrix = np.eye(n_series, dtype=bool).repeat(2, axis=1)

buttons = [
    dict(
        label=f"{key[0]}: {key[1]}",
        method='update',
        args=[{'visible': visibility_matrix[i].tolist()}, {'title': f'Selected: {key[0]}: {key[1]}'}],
    )
    for i, key in enumerate(df.columns)
]

# Add "Show All" button
buttons.append(dict(
    label='Show All',
    method='update',
    args=[{'visible': [True] * (n_series * 2)}, {'title': 'All Traces Visible'}]
))

# Add dropdown menu to layout
fig.update_layout(
    updatemenus=[dict(
        # The figure starts with every trace visible, so mark "Show All" (the last button) active.
        active=len(buttons) - 1,
        buttons=buttons,
        direction='down',
        showactive=True,
        x=1.05,
        xanchor='left',
        y=1,
        yanchor='top'
    )],
    title=title + ' Trace Filter with Dropdown',
    template='plotly_white',
    xaxis_title='Week after Transplant',
    yaxis_title='Harvest (kg)'
)

fig.show()


In [18]:
from src import graphs

# NOTE(review): presumably the packaged version of the inline dropdown figure built above —
# confirm graph_preds_reg in src/graphs.py.
fig = graphs.graph_preds_reg(df, af, title)

In [19]:
# Render the figure returned by graphs.graph_preds_reg.
fig.show()