In [1]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from plotly import graph_objects as go
from plotly import colors
from plotly.colors import sample_colorscale

In [2]:
import sys
import os

# Make the project root (the notebook's parent directory) importable so the `src.*`
# imports in the next cell resolve. Guarded so re-running this cell does not keep
# appending duplicate entries to sys.path.
PROJECT_ROOT = os.path.abspath('..')
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [3]:
from src.table import CherryTable
import src.utils as utils
import src.preprocessing as pre

In [4]:
# Build the train/test splits (split by year, presumably, per the helper's name).
# `mapping_dict` is used further down to decode the test features via `pre.decode`.
train, test, mapping_dict = pre.separate_year(
    planting_meta_path='../data/planting_meta.json',
    weekly_summary_path='../data/weekly_summary.csv',
)

In [5]:
# Fit the harvest model on the training split; the helper prints a per-epoch loss log.
# NOTE(review): no random seed is set in this notebook — if training is stochastic,
# the loss values and predictions will differ between runs.
model = utils.train_harvest_model(train)

Epoch [1/5], Loss: 869098.8096
Epoch [2/5], Loss: 799709.9590
Epoch [3/5], Loss: 792688.7305
Epoch [4/5], Loss: 762202.8954
Epoch [5/5], Loss: 710263.5054


In [6]:
# Score the held-out split with the trained model.
predictions = utils.predict_harvest(
    model,
    test,
)

In [7]:
# Decode the test split back to readable metadata using the mapping from `separate_year`.
meta = pre.decode(test, mapping_dict)
# Ground-truth target as a NumPy array. Assumes `Y_kilos` is a torch tensor (it exposes
# `.detach()`) — TODO confirm. `.cpu()` is a no-op for CPU tensors and is required
# before `.numpy()` when the tensor lives on a GPU.
actuals = test.Y_kilos.detach().cpu().numpy()

In [8]:
# Bundle metadata, the named prediction set(s) and the actuals into one table object.
table = CherryTable(
    meta,
    {'predictions': predictions},
    actuals,
)

In [9]:
# Aggregate per ranch and per class for plotting. `d` is looked up by name later in the
# notebook (e.g. 'actuals_summed'); `a` and `ha_info` are not used in the cells shown here.
a, d, ha_info = table.graph_ready(
    ranches=True,
    classes=True,
)

In [10]:
# Take the first key of `d` (dict insertion order).
# NOTE(review): assumes the first entry is the prediction series rather than
# 'actuals_summed' — verify against `graph_ready`'s output order.
title = [*d][0]

In [11]:
# Display the selected frame: rows are (Ranch, Class) pairs, columns are week offsets.
d[title]

Unnamed: 0_level_0,Unnamed: 1_level_0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
Ranch,Class,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
OAP,BSUF,-5.484486,-9.888428,-145.414291,61.635513,598.579346,-2.569424,-27.74655,72.740891,107.973282,1810.401489,4816.266602,11186.5625,15852.775391,20770.382812,24056.412109,23987.316406,22440.328125,19080.25,15240.095703,8298.606445
OAP,CHE,2.45681,7.943494,-110.876961,25.5982,533.327026,-29.910587,-29.806204,32.443451,15.004826,1266.989624,3784.985596,8272.390625,12201.799805,16853.367188,19649.8125,19700.138672,18235.365234,15363.193359,11833.248047,6383.07666
SFB,BSUF,-0.968893,-2.58451,-53.231823,20.548948,220.215179,-1.542894,-10.530596,25.854286,37.041763,656.682251,1755.427002,4065.805908,5771.964355,7585.699707,8789.068359,8766.237305,8197.151367,6965.821777,5551.927246,3022.415283
SFB,CHE,6.975963,14.065204,-56.760666,-1.648409,309.74411,-28.174015,-18.22612,5.046189,-27.60984,585.230957,1967.473755,4055.883789,6221.390137,9020.082031,10576.40332,10653.318359,9774.263672,8169.068848,6085.078613,3254.897461
SGB,BSUF,-1.744256,-3.504397,-77.44474,29.990992,324.569458,-2.857521,-13.151806,37.348942,52.098984,953.347168,2565.803467,5925.062988,8428.082031,11104.390625,12867.47168,12839.71875,11997.331055,10194.018555,8110.013184,4411.467285
SGB,CHE,2.090929,5.061502,-56.999886,11.244286,276.072754,-16.086643,-13.631971,15.750303,4.52905,640.73761,1934.303589,4207.97998,6225.786621,8638.094727,10073.3125,10105.269531,9345.362305,7870.705078,6042.613281,3256.337891
SJB,BSUF,-0.740367,-1.686371,-42.077534,15.93392,177.131393,-2.021791,-7.685522,19.836365,26.772032,513.898499,1391.111816,3202.841553,4565.09082,6032.45166,6992.543457,6979.362305,6518.500488,5535.058594,4394.73291,2389.801025
SJB,CHE,4.001469,8.502715,-49.061676,3.778631,253.57753,-18.968668,-12.7232,9.334312,-11.505804,524.174683,1677.529663,3545.6875,5348.45166,7602.751953,8889.905273,8939.452148,8231.4375,6904.257324,5212.570312,2797.501465
SMB,BSUF,-2.673324,-6.069169,-90.310112,37.012173,369.729218,-1.681117,-17.151976,44.276539,63.896141,1109.29541,2960.208008,6865.385254,9738.205078,12782.999023,14805.984375,14767.655273,13810.643555,11739.398438,9364.745117,5097.76709
SMB,CHE,3.24902,6.576526,-53.011383,7.103358,263.437744,-17.511181,-14.336872,12.091917,-5.027208,577.258484,1793.26062,3843.761963,5743.07959,8068.996094,9423.833984,9465.50293,8734.083008,7338.564941,5586.088379,3003.856445


In [12]:
# Weeks become the index and each (Ranch, Class) pair becomes a column, ready for plotting.
df = d[title].transpose()

In [13]:
# Same orientation for the summed actuals: weeks as index, (Ranch, Class) as columns.
af = d['actuals_summed'].transpose()

In [14]:
# Display the summed actuals. Note they cover fewer weeks than the predictions and
# include series (the VAP ranch) that have no predicted column.
af

Ranch,OAP,OAP,SFB,SFB,SGB,SGB,SJB,SJB,SMB,SMB,VAP,VAP
Class,BSUF,CHE,BSUF,CHE,BSUF,CHE,BSUF,CHE,BSUF,CHE,BSUF,CHE
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5,0.0,0.0,129.0,213.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
6,247.0,0.0,0.0,1333.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
7,0.0,0.0,852.0,883.0,33.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
8,0.0,796.0,2724.0,2004.0,116.0,36.0,1482.0,507.0,480.0,0.0,24.0,0.0
9,1280.0,2267.0,0.0,0.0,307.0,108.0,4236.0,742.0,0.0,1023.0,587.0,40.0


In [15]:
# Interactive figure: for every (Ranch, Class) series, overlay the predicted curve (line)
# on the summed actuals (bars). A dropdown isolates one series or shows them all.
fig = go.Figure()
x_values = df.index
series_keys = list(df.columns)
n_series = len(series_keys)

# One evenly spaced Viridis sample per series. `max(..., 1)` avoids a ZeroDivisionError
# when there is a single series. Named `trace_colors` so the `plotly.colors` module
# imported at the top of the notebook is not shadowed.
trace_colors = sample_colorscale('Viridis', [i / max(n_series - 1, 1) for i in range(n_series)])
color_map = dict(zip(series_keys, trace_colors))

# Exactly two traces per series, always in the order (predicted line, actual bars);
# the visibility masks below rely on this ordering.
for col in series_keys:
    label = col[0] + ': ' + col[1]
    fig.add_trace(go.Scatter(x=x_values, y=df[col], name=label + ' Predicted', line=dict(color=color_map[col])))
    # Use the actuals' own index for x: `af` covers fewer weeks than `df`, so pairing
    # its values with `df.index` only lines up when both indexes start identically.
    fig.add_trace(go.Bar(x=af.index, y=af[col], name=label + ' Actual', marker_color=color_map[col], opacity=0.75))

# Shape (n_series, 2 * n_series): one row per button, one entry per trace. Row i is
# True exactly at trace positions 2*i and 2*i + 1 (series i's line and bars).
visibility_matrix = np.eye(n_series, dtype=bool).repeat(2, axis=1)

buttons = [
    dict(
        label=f"{key[0]}: {key[1]}",
        method='update',
        args=[{'visible': visibility_matrix[i].tolist()}, {'title': f'Selected: {key[0]}: {key[1]}'}],
    )
    for i, key in enumerate(series_keys)
]

# "Show All" button restores every trace.
buttons.append(dict(
    label='Show All',
    method='update',
    args=[{'visible': [True] * (2 * n_series)}, {'title': 'All Traces Visible'}],
))

# Add dropdown menu to layout
fig.update_layout(
    updatemenus=[dict(
        # The figure starts with every trace visible, so mark "Show All" (the last
        # button) as the active entry rather than the first series.
        active=len(buttons) - 1,
        buttons=buttons,
        direction='down',
        showactive=True,
        x=1.05,
        xanchor='left',
        y=1,
        yanchor='top'
    )],
    title=title + ' Trace Filter with Dropdown',
    template='plotly_white',
    xaxis_title='Week after Transplant',
    yaxis_title='Harvest (kg)'
)

fig.show()


In [16]:
import src.curve_alter as curve_alter

In [17]:
k = 10  # number of leading weeks of actuals kept in the next cell (`[:k]`); presumably the observed window — confirm

In [18]:
# Observed actuals for the first predicted series, truncated to the first `k` weeks.
# Select by the (Ranch, Class) label of df's first column rather than by position:
# `af` has series (e.g. VAP) that `df` lacks, so positional indexing would pair the
# wrong series with `preds` as soon as the two frames' column orders diverge.
# NOTE(review): this rebinds `actuals` (earlier: the full test-set target passed to
# CherryTable); re-running the CherryTable cell after this one would use this slice.
actuals = af[df.columns[0]].to_numpy()[:k]

In [19]:
# Predicted curve for df's first (Ranch, Class) column, across every predicted week.
# NOTE(review): an assignment produces no cell output; the array saved under this cell
# (19 values, magnitudes up to ~1e28) does not match `df` and looks stale — re-run
# from a fresh kernel.
preds = df.transpose().to_numpy()[0]

array([-1.10778678e+05,  2.34142721e+07,  1.42642592e+08, -1.99106100e+10,
       -1.22840958e+09,  1.90660661e+11,  7.18414943e+12, -1.53270185e+14,
        3.69369398e+16, -1.41234620e+18,  4.71489403e+19, -9.60340455e+20,
        1.80846061e+22, -3.01050978e+23,  4.31454980e+24, -5.80133119e+25,
        7.08968847e+26, -8.13908194e+27,  6.36997210e+28])