In [None]:
! pip install --quiet wandb torchinfo pytorch-lightning==1.6.5 pytorch-forecasting==0.10.1 statsmodels==0.13.2

In [None]:
# Clone the project repository; its configs/, utils/, dataloader/ and model/ packages are imported below.
!git clone https://github.com/B-Deforce/Global-Local-Interpretable-Forecasting-through-Information-Fusion-in-Smart-Agriculture.git

# Imports

In [None]:
# standard-library imports
import json
import sys

# third-party imports
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import wandb
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from scipy import stats
from sklearn.preprocessing import MinMaxScaler
from torchinfo import summary

In [None]:
# Local path of the cloned repository (Colab layout); added to sys.path so its packages import.
repo = "/content/Global-Local-Interpretable-Forecasting-through-Information-Fusion-in-Smart-Agriculture"
# Guard so re-running the cell does not append duplicate entries.
if repo not in sys.path:
    sys.path.append(repo)

In [None]:
# custom imports
from configs.config import CFG, sweep_CFG
from utils.config import Config
from dataloader.dataloader import DataLoader
from model.var import MultiVAR
from model.tft import TFT
from model.lstm import LSTM_dataprep, SoilMoist_LSTM, SoilMoistPredictor
from utils.metrics import Metrics

In [None]:
# Parse the project config once and split it into its data and model sections.
full_config = Config.from_json(CFG)
data_config = full_config.data
model_config = full_config.model
# JSON round-trip yields an independent plain copy of sweep_CFG for wandb.sweep.
sweep_config = json.loads(json.dumps(sweep_CFG))

# Read data

**Note:** the authors do not have permission to share the data publicly, so the loading cell below needs your own copy of the dataset.

In [None]:
# Load the full dataset described by the data section of CFG (the data is not publicly shared).
df = DataLoader.load_data(data_config)

# Models

## VAR model

### Initiate model

In [None]:
# VAR baseline: build the MultiVAR wrapper from the VAR section of the model config and the full dataframe.
var = MultiVAR(model_config.var_model, df)
var.load_data()

In [None]:
# Inspect the dataset held by the VAR wrapper (presumably prepared by var.load_data() above).
var.dataset

### Fit model

In [None]:
# Fit the VAR model; the fitted result is passed to var.predict in the next cell.
fitted = var.train()

### Predict

In [None]:
# Forecast with the fitted VAR. Later metric cells treat each element of var_results as a
# pair: [0] = predictions, [1] = ground truth.
var_results = var.predict(fitted)

## LSTM model

### Initiate model

In [None]:
# Build the LSTM datasets and dataloaders from the LSTM section of the model config.
# The fit cell below passes test_dataloader (not val_dataloader) as validation data.
# NOTE: these names are rebound by the TFT data setup later in the notebook.
lstm_dataprep = LSTM_dataprep(model_config.lstm_model, full_df_pure=df)
training, train_dataloader, val_dataloader, test_dataloader = lstm_dataprep.build_dataset(
    train_val_combo=True
    )

In [None]:
# Instantiate the LSTM Lightning module from the LSTM section of the model config.
lstm_model = SoilMoistPredictor(model_config.lstm_model)

### Fit model

In [None]:
# Keep only the best checkpoint by validation loss.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    verbose=True,
    mode='min',
    save_top_k=1,  # keep only the single best model
    filename='2nd_round-{epoch}-{step}-{val_loss:.2f}'
)

# Stop once val_loss fails to improve by at least 1e-4 for 40 consecutive validation checks.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=1e-4,
    patience=40,
    verbose=False,
    mode="min"
    )

# To resume an existing W&B run instead, pass its id with resume="must", e.g.
# WandbLogger(project="SoilMoist_LSTM", id="122mp3he", resume="must")
logger=WandbLogger(project="SoilMoist_LSTM")

# Up to 100 epochs on a single device; accelerator="gpu" requires a GPU runtime.
trainer = pl.Trainer(
    logger=logger,
    callbacks=[early_stop_callback, checkpoint_callback],
    max_epochs=100,
    accelerator="gpu",
    devices=1,
)

In [None]:
# Train the LSTM built above. Pass lstm_model explicitly: `model` is only bound later by the
# TFT cell, so using it here fails on a fresh kernel or trains the wrong network.
# NOTE(review): test_dataloader doubles as the validation set, so early stopping and
# checkpoint selection see the test data; reported test metrics may be optimistic.
trainer.fit(
    lstm_model,
    train_dataloaders=train_dataloader,
    val_dataloaders=test_dataloader,
    )

In [None]:
# Upload the best LSTM checkpoint to the active W&B run, then close that run. Otherwise the
# TFT WandbLogger created later would reuse this run and log TFT metrics into it.
wandb.save(trainer.checkpoint_callback.best_model_path)
wandb.finish()

### Load best model

In [None]:
# Load the pre-trained LSTM checkpoint shipped in the repo (not the run trained above)
# and freeze it for inference.
lstm_model = SoilMoistPredictor.load_from_checkpoint(f"{repo}/best_lstm_model.ckpt", 
                                                model_config=model_config.lstm_model)
lstm_model.freeze()

### Predict

In [None]:
# Run the frozen LSTM over the test set and collect, per batch, the sample index frame
# (x_to_index), the predictions and the targets.
lstm_index_frames = []
lstm_pred = {
    "x": pd.DataFrame(),
    "y_hat": [],
    "y_true": [],
}
for batch in test_dataloader:
    x, y = batch
    pred = lstm_model(x["encoder_cont"])
    # DataFrame.append was removed in pandas 2.0 (and is quadratic in a loop);
    # collect the frames and concatenate once after the loop.
    lstm_index_frames.append(test_dataloader.dataset.x_to_index(x))
    lstm_pred["y_true"].append(y[0].numpy().squeeze())
    # presumably element 1 of the model output is the prediction tensor — confirm in SoilMoistPredictor
    lstm_pred["y_hat"].append(pred[1].detach().numpy())

lstm_pred["x"] = pd.concat(lstm_index_frames)
lstm_pred["y_true"] = np.concatenate(lstm_pred["y_true"], axis=0)
lstm_pred["y_hat"] = np.concatenate(lstm_pred["y_hat"], axis=0)

In [None]:
# last known value
# For each LSTM test window, take the last observed (unscaled) soil_moisture before the
# forecast starts; it anchors the reconstruction of levels from predicted differences.
raw_sm = []
# itertuples keeps per-column dtypes (iterrows upcasts each row to one dtype) and yields plain
# tuples, so positional access no longer relies on Series' deprecated integer-key fallback.
# Positions: 0 = time_idx, 1 = sensor_id, 2 = measurement_year.
for time_idx, sensor_id, year, *_ in lstm_pred["x"].itertuples(index=False, name=None):
    series = df[(df["sensor_id"] == sensor_id) & (df["measurement_year"] == year)]["soil_moisture"]
    # time_idx is the first forecast step, so position time_idx - 1 is the last known value
    # (assumes each sensor/year series starts at time_idx 0 with no gaps — TODO confirm).
    raw_sm.append(series.iloc[time_idx - 1])

raw_sm = np.array(raw_sm).reshape(-1,1)

In [None]:
# Scaler that inverts only the last column of the LSTM data-prep MinMax scaler
# (presumably the target column — confirm against lstm_dataprep's column order).
descaler = MinMaxScaler()
descaler.min_ = lstm_dataprep.minmax_scaler.min_[-1]
descaler.scale_ = lstm_dataprep.minmax_scaler.scale_[-1]

In [None]:
# Descale the predicted differences, then rebuild levels as the last known value plus the
# running sum of differences along the forecast horizon (axis 1).
lstm_ypred = DataLoader.descale(descaler, lstm_pred["y_hat"])
lstm_ypred = np.cumsum(lstm_ypred, axis=1) + raw_sm

In [None]:
# Same reconstruction for the observed differences: last known value + running sum over the horizon.
lstm_ytrue = DataLoader.descale(descaler, lstm_pred['y_true'])
lstm_ytrue = np.cumsum(lstm_ytrue, axis=1) + raw_sm

## TFT

### Sweep

In [None]:
# Register a W&B hyperparameter sweep defined by sweep_CFG in the "TFT_scaled_sweep" project.
sweep_id = wandb.sweep(sweep_config, project="TFT_scaled_sweep")

In [None]:
# Run up to 50 sweep trials with the W&B agent.
# NOTE(review): `sweepstaker` is not defined or imported anywhere in this notebook, so this
# cell raises NameError as written; define the per-trial training function first.
wandb.agent(sweep_id=sweep_id, function=sweepstaker, count=50)

### Initiate model

In [None]:
# Wrap the full dataframe with the TFT helper configured by the TFT section of the model config.
tft = TFT(model_config.tft_model, full_df_pure=df)

In [None]:
# setup data
# Rebinds training / train / val / test dataloaders for the TFT, replacing the LSTM ones
# bound under the same names earlier in the notebook.
training, train_dataloader, val_dataloader, test_dataloader = tft.build_dataset(
    train_val_combo=True,
    )

### Fit model

In [None]:
wandb_logger = WandbLogger()

# setup callbacks: stop once val_MAE fails to improve by at least 1e-5 for 50 validation checks
early_stop_callback = EarlyStopping(monitor="val_MAE",
                                    min_delta=1e-5,
                                    patience=50,
                                    verbose=False,
                                    mode="min"
                                    )
# keep only the best checkpoint by val_MAE
model_callback = ModelCheckpoint(
    monitor="val_MAE",
    verbose=True,
    mode='min',
    save_top_k=1,  # keep only the single best model
    filename='{epoch}-{step}-{val_MAE:.4f}'
    )

# setup trainer and model from the training dataset; hyperparameter handling lives inside
# tft.build_model (not visible in this notebook)
trainer, model = tft.build_model(
    training,
    wandb_logger,
    callbacks=[early_stop_callback, model_callback],
)

In [None]:
# train
# refit on full train and val dataloader
# NOTE(review): test_dataloader doubles as the validation set, so early stopping and
# checkpoint selection (both on val_MAE) see the test data; test metrics may be optimistic.
trainer.fit(
    model, 
    train_dataloaders=train_dataloader, 
    val_dataloaders=test_dataloader,
    )

# Upload the best TFT checkpoint to W&B and close the run.
wandb.save(trainer.checkpoint_callback.best_model_path)
wandb.finish()

### Load best model

In [None]:
# load the best model: the pre-trained TFT checkpoint shipped in the repo (not the run
# trained above), frozen for inference
best_tft = TemporalFusionTransformer.load_from_checkpoint(
    f"{repo}/best_tft_model.ckpt"
)
best_tft.freeze()

### Predict

In [None]:
# Raw TFT output on the test set: mode="raw" keeps the full quantile tensor under
# raw_pred["prediction"]; x holds the matching network inputs (encoder/decoder targets etc.).
raw_pred, x = best_tft.predict(test_dataloader, 
                                return_x=True,
                                mode="raw"
                                )

In [None]:
# last known value
# For each TFT test sample, take the last observed (unscaled) soil_moisture before the
# forecast starts; it anchors the reconstruction of levels from predicted differences.
raw_sm = []
# itertuples keeps per-column dtypes (iterrows upcasts each row to one dtype) and yields plain
# tuples, so positional access no longer relies on Series' deprecated integer-key fallback.
# Positions: 0 = time_idx, 1 = sensor_id, 2 = measurement_year.
index_rows = test_dataloader.dataset.x_to_index(x).itertuples(index=False, name=None)
for time_idx, sensor_id, year, *_ in index_rows:
    series = df[(df["sensor_id"] == sensor_id) & (df["measurement_year"] == year)]["soil_moisture"]
    # time_idx is the first forecast step, so position time_idx - 1 is the last known value
    # (assumes each sensor/year series starts at time_idx 0 with no gaps — TODO confirm).
    raw_sm.append(series.iloc[time_idx - 1])

raw_sm = np.array(raw_sm).reshape(-1,1)

In [None]:
# Position of the scaled soil-moisture-difference variable among the dataset's real inputs.
# NOTE(review): not referenced by any later cell in this notebook (the descalers below
# hard-code scaler columns instead) — confirm whether it is still needed.
sm_diff_idx = test_dataloader.dataset.reals.index('soil_moisture_diff_scaled')

In [None]:
# Scaler that inverts column 5 of the TFT MinMax scaler (presumably the soil-moisture-difference
# target — confirm against the scaler's fitted column order).
target_descaler = MinMaxScaler()
target_descaler.min_ = tft.minmax_scaler.min_[5]
target_descaler.scale_ = tft.minmax_scaler.scale_[5]

In [None]:
# Scaler that inverts column 0 of the TFT MinMax scaler (presumably raw soil moisture — confirm).
# NOTE(review): sm_descaler is not used by any later cell in this notebook.
sm_descaler = MinMaxScaler()
sm_descaler.min_ = tft.minmax_scaler.min_[0]
sm_descaler.scale_ = tft.minmax_scaler.scale_[0]

In [None]:
# Quantile index 2 of the raw prediction (presumably the median of five quantiles; the plot
# below labels indices 1/3 as the IQR and 0/4 as the .1/.9 quantiles), descaled and then
# rebuilt into levels: last known value + running sum along the horizon (axis 1).
tft_ypred = raw_pred["prediction"][:,:,2].numpy()
tft_ypred = DataLoader.descale(target_descaler, tft_ypred)
tft_ypred = np.cumsum(tft_ypred, axis=1) + raw_sm

In [None]:
# Observed target differences over the forecast window, descaled and rebuilt into levels
# the same way (last known value + running sum along the horizon).
tft_ytrue = x["decoder_target"].numpy()
tft_ytrue = DataLoader.descale(target_descaler, tft_ytrue)
tft_ytrue = np.cumsum(tft_ytrue, axis=1) + raw_sm

In [None]:
# One (samples, horizon) slice per quantile index of the raw prediction's last axis.
sm_quantiles = {
    f"quantile_{q_idx}": raw_pred["prediction"][:, :, q_idx]
    for q_idx in range(raw_pred["prediction"].shape[-1])
}

# Evaluation

## Metrics

### MAE

In [None]:
# Per-sample MAE over the forecast horizon for each model, summarised by median and IQR.
# Naive baseline: the last known value repeated over the horizon.
# We use tft_ytrue as ground truth but could also use lstm/var_ytrue.
# Shape comes from tft_ytrue instead of a hard-coded (353, 5).
naive_ypred = np.broadcast_to(raw_sm, shape=tft_ytrue.shape)
naive_mae = np.abs(naive_ypred - tft_ytrue).mean(axis=1)
# VAR: one (prediction, truth) pair per element of var_results
var_mae = [np.abs(i[0] - i[1]).mean() for i in var_results]
lstm_mae = np.abs(lstm_ypred - lstm_ytrue).mean(axis=1)
tft_mae = np.abs(tft_ypred - tft_ytrue).mean(axis=1)

for name, scores in [("Naive", naive_mae), ("VAR", var_mae), ("LSTM", lstm_mae), ("TFT", tft_mae)]:
    print(f"{name} median MAE: {np.median(scores)}")
    print(f"{name} MAE iqr: {stats.iqr(scores)}")

### RMSE

In [None]:
# Per-sample RMSE over the forecast horizon for each model, summarised by median and IQR.
# Naive baseline: the last known value repeated over the horizon.
# We use tft_ytrue as ground truth but could also use lstm/var_ytrue.
# Shape comes from tft_ytrue instead of a hard-coded (353, 5).
naive_ypred = np.broadcast_to(raw_sm, shape=tft_ytrue.shape)
naive_rmse = np.sqrt(((naive_ypred - tft_ytrue)**2).mean(axis=1))
# VAR: one (prediction, truth) pair per element of var_results
var_rmse = [np.sqrt(((i[0] - i[1])**2).mean()) for i in var_results]
lstm_rmse = np.sqrt(((lstm_ypred - lstm_ytrue)**2).mean(axis=1))
tft_rmse = np.sqrt(((tft_ypred - tft_ytrue)**2).mean(axis=1))

for name, scores in [("Naive", naive_rmse), ("VAR", var_rmse), ("LSTM", lstm_rmse), ("TFT", tft_rmse)]:
    print(f"{name} median RMSE: {np.median(scores)}")
    print(f"{name} RMSE iqr: {stats.iqr(scores)}")

### MDA

In [None]:
# Mean directional accuracy (MDA) per model with a 95% normal-approximation half-width
# (1.96 * std / sqrt(n)).
# NOTE(review): `MDA` is not defined or imported anywhere in this notebook (only `Metrics` is
# imported from utils.metrics), so this cell raises NameError as written — confirm its source.
# The statistical-tests cell calls `.mean(axis=1)` on its result, so it presumably returns a
# 2-D per-sample array.
# VAR
var_mda = [MDA(i[0].reshape(1, -1), i[1].reshape(1, -1)) for i in var_results]
var_mda_ci = 1.96*np.std(var_mda)/np.sqrt(len(var_mda))
print(f"VAR mean MDA: {np.mean(var_mda)}")
print(f"VAR MDA CI: {var_mda_ci}")

# LSTM
lstm_mda = MDA(lstm_ypred, lstm_ytrue)
lstm_mda_ci = 1.96*np.std(lstm_mda)/np.sqrt(len(lstm_mda))
print(f"LSTM mean MDA: {np.mean(lstm_mda)}")
print(f"LSTM MDA CI: {lstm_mda_ci}")

# TFT
tft_mda = MDA(tft_ypred, tft_ytrue)
tft_mda_ci = 1.96*np.std(tft_mda)/np.sqrt(len(tft_mda))
print(f"TFT mean MDA: {np.mean(tft_mda)}")
print(f"TFT MDA CI: {tft_mda_ci}")

### Q-RISK

In [None]:
# Q-risk at q=0.5 and q=0.9 for each model, via Metrics.q_risk.
# Naive baseline: the last known value repeated over the horizon.
# we use tft_ytrue but could also use lstm/var_ytrue
naive_ypred = np.broadcast_to(raw_sm, shape=tft_ytrue.shape)
naive_qrisk_05 = Metrics.q_risk(naive_ypred, tft_ytrue, q=0.5)
naive_qrisk_09 = Metrics.q_risk(naive_ypred, tft_ytrue, q=0.9)
print(f"Naive q-risk .5 - .9: {naive_qrisk_05} - {naive_qrisk_09}")

# VAR: stack the per-element (prediction, truth) pairs into 2-D arrays.
# Use var_* names so the naive results above are not overwritten.
var_ypred = np.vstack([i[0] for i in var_results])
var_ytrue = np.vstack([i[1] for i in var_results])
var_qrisk_05 = Metrics.q_risk(var_ypred, var_ytrue, q=0.5)
var_qrisk_09 = Metrics.q_risk(var_ypred, var_ytrue, q=0.9)
print(f"VAR q-risk .5 - .9: {var_qrisk_05} - {var_qrisk_09}")

# LSTM
lstm_qrisk_05 = Metrics.q_risk(lstm_ypred, lstm_ytrue, q=0.5)
lstm_qrisk_09 = Metrics.q_risk(lstm_ypred, lstm_ytrue, q=0.9)
print(f"LSTM q-risk .5 - .9: {lstm_qrisk_05} - {lstm_qrisk_09}")

# TFT
tft_qrisk_05 = Metrics.q_risk(tft_ypred, tft_ytrue, q=0.5)
tft_qrisk_09 = Metrics.q_risk(tft_ypred, tft_ytrue, q=0.9)
print(f"TFT q-risk .5 - .9: {tft_qrisk_05} - {tft_qrisk_09}")

## Statistical tests

In [None]:
# One-sided Wilcoxon signed-rank tests (H1: the TFT's per-sample error is lower) for MAE and RMSE,
# and chi-square tests comparing each model's distribution of per-sample MDA values to the TFT's.
# Each line names the comparison, whether the p-value clears the corrected threshold, and p.
# NOTE(review): Bonferroni divides alpha by the number of comparisons in the family; MAE/RMSE
# each run 3 comparisons but divide by 4, and MDA runs 2 but divides by 3. The original
# thresholds are kept unchanged — confirm the intended family size.
# NOTE(review): chisquare needs observed/expected count arrays of equal length (and, in recent
# SciPy versions, matching totals); this only holds if both models yield the same set of
# distinct per-sample MDA values over the same number of samples.
ALPHA = 0.05

print("*************MAE*****************")
for name, errors in [("LSTM", lstm_mae), ("VAR", np.array(var_mae)), ("Naive", naive_mae)]:
    p_value = stats.wilcoxon(tft_mae, errors, alternative="less")[1]
    print(f"TFT < {name}: {p_value < ALPHA / 4} (p={p_value:.3g})")  # Bonferroni correction
print("*************RMSE***************")
for name, errors in [("LSTM", lstm_rmse), ("VAR", np.array(var_rmse)), ("Naive", naive_rmse)]:
    p_value = stats.wilcoxon(tft_rmse, errors, alternative="less")[1]
    print(f"TFT < {name}: {p_value < ALPHA / 4} (p={p_value:.3g})")  # Bonferroni correction
print("*************MDA**************")
tft_mda_counts = np.unique(tft_mda.mean(axis=1), return_counts=True)[1]
for name, mda in [("LSTM", lstm_mda), ("VAR", np.concatenate(var_mda))]:
    mda_counts = np.unique(mda.mean(axis=1), return_counts=True)[1]
    p_value = stats.chisquare(mda_counts, tft_mda_counts)[1]
    print(f"TFT vs {name}: {p_value < ALPHA / 3} (p={p_value:.3g})")  # Bonferroni correction

# Plots

In [None]:
# Per-sample comparison table: TFT test-sample index joined with each series' orchard_id,
# plus every model's per-sample MAE.
# NOTE(review): assumes the LSTM/VAR/naive MAE arrays are ordered like the TFT test samples — confirm.
orchard_lookup = (
    df[["sensor_id", "measurement_year", "orchard_id"]]
    .drop_duplicates()
    .set_index(["sensor_id", "measurement_year"])
)
comp_df = (
    test_dataloader.dataset.x_to_index(x)
    .join(orchard_lookup, on=["sensor_id", "measurement_year"], how="left")
    .assign(lstm_mae=lstm_mae, var_mae=var_mae, naive_mae=naive_mae, tft_mae=tft_mae)
)

In [None]:
# MAE plot per orchard: one group of four notched boxes (VAR, Naive, LSTM, TFT) per orchard.
width = 0.6
# Build tick labels and box groups from the same sorted id list. Previously labels followed
# unique() (appearance order) while groups were fetched as ids 0..k-1, so labels could
# mismatch boxes and non-0-based ids raised KeyError.
orchard_ids = sorted(comp_df["orchard_id"].dropna().unique())
x_lab = [f"orchard_{i}" for i in orchard_ids]
x_idx = np.arange(0, len(x_lab)*2, 2)
mae_by_orchard = comp_df.groupby("orchard_id")

# (MAE column, legend label, outlier marker, hatch, x-offset from the group centre)
box_specs = [
    ("var_mae", "VAR", "x", "xxx", -width*3/4),
    ("naive_mae", "Naïve", "+", "+++", -width/4),
    ("lstm_mae", "LSTM", "o", "...", width/4),
    ("tft_mae", "TFT", "^", "///", width*3/4),
]

fig, ax = plt.subplots(figsize=(5.42,4))
legend_handles = []
for col, _label, marker, hatch, offset in box_specs:
    box_plot = ax.boxplot(
        [mae_by_orchard[col].get_group(oid) for oid in orchard_ids],
        positions=x_idx + offset,
        widths=0.25,
        notch=True,
        sym=marker,
        medianprops={"c": "k"},
        patch_artist=True,
        flierprops={"markersize": 5},
    )
    # Hatched, unfilled boxes keep the figure readable in greyscale print.
    for box in box_plot["boxes"]:
        box.set(hatch=hatch, fill=False)
    legend_handles.append(box_plot["boxes"][0])

ax.set_xticks(x_idx)
ax.set_xticklabels(x_lab)
ax.set_ylabel(r"$MAE$")
ax.set_xlabel("Orchard ID")
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.legend(legend_handles, [spec[1] for spec in box_specs])

fig.savefig('mae_plot.eps', dpi=300, format="eps", bbox_inches='tight')

In [None]:
# Interpretation with the default reduction; the plotting cells below index
# interpretation["attention"] per test sample.
interpretation = best_tft.interpret_output(raw_pred)

In [None]:
# single plot: one test sample's TFT forecast (descaled differences) with quantile bands,
# and that sample's encoder attention on a secondary axis.
i = 317  # index of the test sample to plot
y_hat = DataLoader.descale(target_descaler, raw_pred["prediction"][i,:,2].reshape(-1,1))
y_true_plot = DataLoader.descale(target_descaler, x["decoder_target"].numpy()[i].reshape(-1,1))
y_input = DataLoader.descale(target_descaler, x["encoder_target"].numpy()[i].reshape(-1,1))

idx = 0  # x-position of the first forecast step
# Time axes derived from the data instead of hard-coded 7 encoder / 5 forecast steps.
enc_steps = range(idx - len(y_input), idx)
dec_steps = range(idx, idx + len(y_true_plot))

fig, ax1 = plt.subplots(figsize=(5.42,4))
ax1.plot(enc_steps, y_input, c="k")
obs = ax1.plot(dec_steps, y_true_plot, c="k", label="observed")
pred = ax1.plot(dec_steps, y_hat, c="red", label="predicted")

# Quantile bands: indices 0/4 dotted (.1/.9), 1/3 dashed (IQR).
ax1.plot(dec_steps, DataLoader.descale(target_descaler, sm_quantiles["quantile_0"][i].reshape(-1,1)), alpha=0.5, linestyle=":", c="red")
iqr = ax1.plot(dec_steps, DataLoader.descale(target_descaler, sm_quantiles["quantile_1"][i].reshape(-1,1)), alpha=0.5, linestyle="--", c="red", label="IQR")
ax1.plot(dec_steps, DataLoader.descale(target_descaler, sm_quantiles["quantile_3"][i].reshape(-1,1)), alpha=0.5, linestyle="--", c="red")
q = ax1.plot(dec_steps, DataLoader.descale(target_descaler, sm_quantiles["quantile_4"][i].reshape(-1,1)), alpha=0.5, linestyle=":", c="red", label=".1:.9-quantile")

ax2 = ax1.twinx()
att = ax2.plot(enc_steps, interpretation["attention"][i], c="k", alpha=0.2, label="attention")

# Merge handles from both axes into a single legend.
lns = obs+pred+att+iqr+q
labs = [l.get_label() for l in lns]
ax1.legend(lns, labs, bbox_to_anchor=(1.5, 1))
ax1.spines["top"].set_visible(False)
ax2.spines["top"].set_visible(False)

ax1.set_ylabel(r"$\Delta_1$(Soil water potential ($\psi_{soil}$))")
ax2.set_ylabel("Attention")
ax1.set_title(f"TFT {len(dec_steps)}-day Forecast")
all_steps = list(range(enc_steps.start, dec_steps.stop))
ax1.set_xticks(all_steps)
ax1.set_xticklabels(all_steps)
ax1.set_xlabel("Time index")
fig.savefig('val_forecast.svg', dpi=300, format="svg", bbox_inches='tight')

In [None]:
# Distribution of TFT attention per encoder step across all test samples:
# the mean as a dotted line, the spread as notched boxes.
attention = interpretation["attention"]
n_enc = attention.shape[1]  # number of encoder steps, instead of a hard-coded 7
enc_steps = range(-n_enc, 0)

fig, ax = plt.subplots(figsize=(5.42,4))
ax.plot(enc_steps,
        attention.mean(axis=0),
        c="k",
        marker="^",
        markersize=5,
        label="mean",
        linestyle=":"
        )
ax.boxplot(
        [attention[:, step] for step in range(n_enc)],
        positions=enc_steps,
        notch=True,
        medianprops={"c":"k"},
        flierprops={"markersize": 5}
        )

ax.set_ylabel(r"Attention")
ax.set_xlabel("Time index")
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.legend()

fig.savefig('attention_plot.eps', dpi=300, format="eps", bbox_inches='tight')

In [None]:
# Re-interpret with reduction="sum" for aggregate variable importances.
# NOTE: this overwrites the per-sample `interpretation` used by the attention plots above;
# re-running those cells after this one will likely fail or plot the wrong thing.
interpretation = best_tft.interpret_output(raw_pred, reduction="sum")

In [None]:
# Variable-importance panels (as % of each group's total) for static, historical (encoder)
# and known-future (decoder) inputs, sorted ascending.
# NOTE(review): the readable y-tick names are hard-coded in the importance order produced by
# this particular checkpoint; they silently become wrong if the ranking changes — check them
# against the raw variable names (tick_label) before publishing.
static_names = ["Orchard name", "Measurement year", "Sensor depth", "Irrigation treatment", "Soil texture", "Pruning treatment", "Soil water potential center", "Soil water potential scale"]
panels = [
    # (interpretation key, raw variable names, title, hatch, readable y-tick names)
    ("static_variables", best_tft.static_variables, "Static var importance", "/////",
     list(reversed(static_names))),
    ("encoder_variables", best_tft.encoder_variables, "Historical var importance", ".....",
     ["Relative time index", "ETo", "Irrigation amount", r"$\Delta_1$(soil water potential)", "Precipitation", "Soil temperature", "Measurement month"]),
    ("decoder_variables", best_tft.decoder_variables, "Known var importance", "xxxxx",
     ["Relative time index", "ETo", "Measurement month", "Precipitation"]),
]

n_static = len(interpretation["static_variables"])
fig, ax = plt.subplots(nrows=3, ncols=1, figsize=(5.42, 3*n_static * 0.25 + 2), constrained_layout=True)
for panel_ax, (key, labels, title, hatch, y_names) in zip(ax, panels):
    # Normalise consistently in numpy (encoder/decoder previously divided a tensor by a numpy sum).
    raw_importance = interpretation[key].numpy()
    values = raw_importance / np.sum(raw_importance)
    order = np.argsort(values)
    panel_ax.barh(np.arange(len(values)), values[order] * 100, tick_label=np.asarray(labels)[order], fill=False, hatch=hatch)
    panel_ax.set_title(title)
    panel_ax.set_xlabel("Importance in %")
    panel_ax.set_yticklabels(y_names)

fig.savefig('var_imp.eps', dpi=300, format="eps")