In [1]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from plotly import graph_objects as go
from plotly import colors
from plotly.colors import sample_colorscale

In [2]:
import sys
import os

# Make the project root importable so the `src.*` imports below resolve.
# '..' is resolved against the current working directory (normally the
# notebook's own directory under Jupyter). The membership check keeps this
# cell idempotent: re-running it no longer appends a duplicate sys.path entry.
project_root = os.path.abspath('..')
if project_root not in sys.path:
    sys.path.append(project_root)

In [3]:
from src.table import CherryTable
import src.utils as utils
import src.preprocessing as pre

In [4]:
# Build the train/test splits (presumably split by season/year, per the helper's
# name). mapping_dict is needed later to decode the test set's encoded columns.
train, test, mapping_dict = pre.separate_year(
    planting_meta_path='../data/planting_meta.json',
    weekly_summary_path='../data/weekly_summary.csv',
)

In [5]:
# Fit the harvest model on the training split; the helper prints per-epoch loss.
# NOTE(review): no random seed is set anywhere in this notebook — if training is
# stochastic, the model and downstream figures will differ between runs.
model = utils.train_harvest_model(
    train,
)

Epoch [1/5], Loss: 837799.2586
Epoch [2/5], Loss: 800551.9652
Epoch [3/5], Loss: 794964.5982
Epoch [4/5], Loss: 769445.6730
Epoch [5/5], Loss: 714376.1796


In [6]:
# Score the held-out split with the freshly trained model.
predictions = utils.predict_harvest(
    model,
    test,
)

In [7]:
# Turn the test set's encoded fields back into labels using mapping_dict.
meta = pre.decode(test, mapping_dict)
# Ground-truth kilos as a NumPy array; Y_kilos looks like a torch tensor
# (hence detach() before numpy()) — TODO confirm.
actuals = (
    test.Y_kilos
    .detach()
    .numpy()
)

In [8]:
# Bundle metadata, the named prediction series, and the actuals for reporting.
table = CherryTable(meta, dict(predictions=predictions), actuals)

In [9]:
# Aggregate per ranch and per class. `d` maps table names (e.g. 'actuals_summed')
# to DataFrames used by the plotting cells below.
a, d, ha_info = table.graph_ready(
    ranches=True,
    classes=True,
)

In [10]:
# Use the first table in `d` (insertion order) as the one to plot.
# NOTE(review): relies on dict ordering; consider naming the key explicitly.
title = list(d)[0]

In [11]:
# Display the selected table: one row per (Ranch, Class) pair, one column per
# week index (the axis later plotted as 'Week after Transplant').
d[title]

Unnamed: 0_level_0,Unnamed: 1_level_0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
Ranch,Class,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
OAP,BSUF,4.31699,-0.356808,7.236816,-3.727654,4.017518,186.897842,-23.817627,127.639946,268.985382,1181.293945,3924.220703,8646.592773,12424.475586,17929.826172,19206.568359,19806.332031,17934.896484,15447.476562,12053.061523,6032.179199
OAP,CHE,-9.002924,2.536819,16.079273,-0.373222,-126.60804,262.015015,-14.808757,58.146477,169.493652,1005.70343,3772.435791,7930.842773,11745.31543,16996.390625,18312.982422,18951.648438,17178.904297,14436.757812,11332.753906,5691.674316
SFB,BSUF,0.184981,0.426452,3.739281,-0.31429,-14.789971,74.882851,-6.410207,35.42482,81.613914,392.267883,1355.246094,2939.635498,4268.192871,6164.77002,6618.502441,6833.330566,6190.831543,5285.952148,4132.929688,2071.364746
SFB,CHE,-11.051155,3.601353,13.794744,3.287706,-140.054321,194.521393,-1.906226,-0.604939,53.501953,521.745605,2200.857178,4435.632812,6757.138672,9801.185547,10619.804688,11023.788086,10005.254883,8219.796875,6488.416016,3270.770508
SGB,BSUF,2.214371,-0.157831,4.311312,-1.021194,-3.763134,106.265991,-11.204788,66.078766,142.93306,642.324585,2151.467529,4723.592773,6805.334961,9823.209961,10529.787109,10861.494141,9838.618164,8453.385742,6599.384277,3304.352051
SGB,CHE,-5.386551,1.296161,9.072071,0.832655,-76.093681,140.263245,-5.452521,22.473885,76.592896,496.821442,1906.464355,3972.817139,5919.606445,8570.81543,9246.806641,9575.30957,8683.989258,7259.983398,5705.834473,2868.196045
SJB,BSUF,0.385468,0.595691,2.788979,-0.06665,-11.536199,62.059765,-5.20028,30.069008,69.203979,329.280518,1135.046631,2464.679932,3576.049316,5164.634766,5544.788086,5724.050781,5186.167969,4430.182617,3463.237305,1735.925781
SJB,CHE,-7.391114,3.034059,9.702473,2.763089,-100.565376,147.280762,-1.63016,4.429385,49.736313,422.960327,1741.526611,3540.22168,5362.467773,7774.291016,8416.575195,8730.796875,7923.563477,6537.481934,5154.464355,2597.275391
SMB,BSUF,0.578284,0.267277,6.035095,-1.621244,-17.216637,122.51416,-12.223793,65.717491,144.893692,677.541748,2314.688232,5043.609863,7299.981445,10541.130859,11308.856445,11671.748047,10572.533203,9050.833008,7072.157227,3542.848633
SMB,CHE,-6.260203,1.753505,9.272972,0.790629,-81.982178,141.718262,-4.9336,19.261028,71.031242,480.951447,1874.049683,3885.215332,5807.877441,8410.885742,9080.342773,9406.34668,8531.004883,7114.069336,5594.873047,2813.837891


In [12]:
# Weeks as rows, (Ranch, Class) pairs as columns — the layout the plot expects.
df = d[title].transpose()

In [13]:
# Same orientation for the summed actuals so columns line up with `df`.
af = d['actuals_summed'].transpose()

In [14]:
# Inspect the 'actuals_summed' table: one row per week index, one column per
# (Ranch, Class) pair. It includes VAP columns that the predictions table above does not show.
af

Ranch,OAP,OAP,SFB,SFB,SGB,SGB,SJB,SJB,SMB,SMB,VAP,VAP
Class,BSUF,CHE,BSUF,CHE,BSUF,CHE,BSUF,CHE,BSUF,CHE,BSUF,CHE
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5,0.0,0.0,129.0,213.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
6,247.0,0.0,0.0,1333.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
7,0.0,0.0,852.0,883.0,33.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
8,0.0,796.0,2724.0,2004.0,116.0,36.0,1482.0,507.0,480.0,0.0,24.0,0.0
9,1280.0,2267.0,0.0,0.0,307.0,108.0,4236.0,742.0,0.0,1023.0,587.0,40.0


In [15]:
# Interactive comparison of predicted (line) vs. actual (bar) harvest for every
# (Ranch, Class) column of `df`, with a dropdown that isolates one pair or shows all.
fig = go.Figure()
x_values = df.index
n_series = len(df.columns)

# One evenly spaced Viridis colour per (Ranch, Class) column. max(..., 1) keeps
# the single-column case from dividing by zero. The list is named `trace_colors`
# so it does not shadow the `plotly.colors` module imported at the top.
trace_colors = sample_colorscale(
    'Viridis', [i / max(n_series - 1, 1) for i in range(n_series)]
)
color_map = dict(zip(df.columns, trace_colors))

# Traces are added in pairs: for column i the predicted line is trace 2*i and
# the actual bar is trace 2*i + 1.
for col in df.columns:
    label = col[0] + ': ' + col[1]
    fig.add_trace(go.Scatter(x=x_values, y=df[col], name=label + ' Predicted',
                             line=dict(color=color_map[col])))
    # Use the actuals frame's own index: it may cover a different span of weeks
    # than the predictions, and pairing af values with df.index misplaces bars.
    fig.add_trace(go.Bar(x=af.index, y=af[col], name=label + ' Actual',
                         marker_color=color_map[col], opacity=0.75))

# Row i is a boolean mask with exactly one entry per trace (2 * n_series), with
# only traces 2*i and 2*i + 1 switched on.
visibility_matrix = np.eye(n_series, dtype=bool).repeat(2, axis=1)

buttons = [
    dict(
        label=f"{key[0]}: {key[1]}",
        method='update',
        args=[{'visible': visibility_matrix[i].tolist()},
              {'title': f'Selected: {key[0]}: {key[1]}'}],
    )
    for i, key in enumerate(df.columns)
]

# Add "Show All" button
buttons.append(dict(
    label='Show All',
    method='update',
    args=[{'visible': [True] * (2 * n_series)}, {'title': 'All Traces Visible'}]
))

# Add dropdown menu to layout. Every trace starts visible, so mark "Show All"
# (the last button) as the active entry instead of the first pair.
fig.update_layout(
    updatemenus=[dict(
        active=len(buttons) - 1,
        buttons=buttons,
        direction='down',
        showactive=True,
        x=1.05,
        xanchor='left',
        y=1,
        yanchor='top'
    )],
    title=title + ' Trace Filter with Dropdown',
    template='plotly_white',
    xaxis_title='Week after Transplant',
    yaxis_title='Harvest (kg)'
)

fig.show()


In [16]:
import src.graphs as graphs
import src.finetuning as finetuning  # NOTE(review): unused in the visible cells

# Project helper for the predicted-vs-actual plot, presumably the packaged form
# of the hand-built figure above — TODO confirm whether it shows the figure itself.
fig = graphs.graph_preds_reg(
    df,
    af,
    title,
)