In [21]:
import sales_forecasting.config as config
import sales_forecasting.models.train as train_module

In [22]:
import sales_forecasting.models.model as model

print(len(list((model.GATGCNLSTM(4, 16).parameters()))))
print(len(list((model.GCNLSTMBaseline(4, 16).parameters()))))
print(len(list((model.LSTMBaseline(4, 16).parameters()))))

31
25
6


In [23]:
model_dict = {
    "GATGCNLSTM": model.GATGCNLSTM(4, 4),
    "GCNLSTMBaseline": model.GCNLSTMBaseline(4, 4),
    "LSTMBaseline": model.LSTMBaseline(4, 4),
}

for name, m in model_dict.items():
    total_params = sum(p.numel() for p in m.parameters())
    print(f"{name} has {total_params} parameters")

GATGCNLSTM has 241 parameters
GCNLSTMBaseline has 193 parameters
LSTMBaseline has 165 parameters


In [24]:
for model_name, m in model_dict.items():
    print(f"\n{model_name} parameter breakdown:")
    for name, param in m.named_parameters():
        print(f"{name}: {param.numel()} parameters")


GATGCNLSTM parameter breakdown:
gat.att: 4 parameters
gat.bias: 4 parameters
gat.lin_l.weight: 16 parameters
gat.lin_l.bias: 4 parameters
gat.lin_r.weight: 16 parameters
gat.lin_r.bias: 4 parameters
gcnlstm.w_c_i: 4 parameters
gcnlstm.b_i: 4 parameters
gcnlstm.w_c_f: 4 parameters
gcnlstm.b_f: 4 parameters
gcnlstm.b_c: 4 parameters
gcnlstm.w_c_o: 4 parameters
gcnlstm.b_o: 4 parameters
gcnlstm.conv_x_i.bias: 4 parameters
gcnlstm.conv_x_i.lins.0.weight: 16 parameters
gcnlstm.conv_h_i.bias: 4 parameters
gcnlstm.conv_h_i.lins.0.weight: 16 parameters
gcnlstm.conv_x_f.bias: 4 parameters
gcnlstm.conv_x_f.lins.0.weight: 16 parameters
gcnlstm.conv_h_f.bias: 4 parameters
gcnlstm.conv_h_f.lins.0.weight: 16 parameters
gcnlstm.conv_x_c.bias: 4 parameters
gcnlstm.conv_x_c.lins.0.weight: 16 parameters
gcnlstm.conv_h_c.bias: 4 parameters
gcnlstm.conv_h_c.lins.0.weight: 16 parameters
gcnlstm.conv_x_o.bias: 4 parameters
gcnlstm.conv_x_o.lins.0.weight: 16 parameters
gcnlstm.conv_h_o.bias: 4 parameters
gc

In [25]:
import torch.nn as nn
import torch_geometric.nn
from torch_geometric_temporal.nn.recurrent import GConvLSTM


class LSTMBaseline(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)  # predict scalar target

    def forward(self, x):
        # x: [B, window_size, F]
        out, _ = self.lstm(x)  # out: [B, window_size, hidden_size]
        out = out[:, -1, :]  # take output of last timestep: [B, hidden_size]
        out = self.fc(out)  # [B, 1]
        return out.squeeze(-1)  # [B]


class GCNLSTMBaseline(nn.Module):
    def __init__(self, input_size, hidden_size, K=1):
        super().__init__()
        self.gcnlstm = GConvLSTM(in_channels=input_size, out_channels=hidden_size, K=K)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, x, edge_index, edge_weight, h=None, c=None):
        # x: [window_size, N, F]
        # edge_index: [window_size, 2, E]
        # edge_weight: [window_size, E]
        # h: [N, hidden_size]
        # c: [N, hidden_size]

        window_size, N, F = x.shape
        for t in range(window_size):
            # Pass previous hidden states (if any)
            h, c = self.gcnlstm(
                x[t, :, :], edge_index[t, :, :], edge_weight[t, :], h, c
            )
        out = self.linear(h)
        return out.squeeze(-1), (h, c)


class GATGCNLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, K=1):
        super().__init__()
        self.gat = torch_geometric.nn.conv.GATv2Conv(
            in_channels=input_size,
            out_channels=input_size,
        )
        self.gcnlstm = GConvLSTM(in_channels=input_size, out_channels=hidden_size, K=K)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, x, edge_index, edge_weight, h=None, c=None):
        # x: [window_size, N, F]
        # edge_index: [window_size, 2, E]
        # edge_weight: [window_size, E]
        # h: [N, hidden_size]
        # c: [N, hidden_size]

        window_size, N, F = x.shape
        for t in range(window_size):
            # Apply GAT to get edge weights
            (_, (e_index, attention_weights)) = self.gat(
                x=x[t, :, :],
                edge_index=edge_index[t, :, :],
                return_attention_weights=True,
            )
            # Pass previous hidden states (if any)
            h, c = self.gcnlstm(
                x[t, :, :], e_index, attention_weights.squeeze(-1), h, c
            )
        out = self.linear(h)
        return out.squeeze(-1), (h, c)

In [26]:
# Replace LSTM with GRU in the baseline model
class GRUBaseline(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super().__init__()
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)  # predict scalar target

    def forward(self, x):
        # x: [B, window_size, F]
        out, _ = self.gru(x)  # out: [B, window_size, hidden_size]
        out = out[:, -1, :]  # take output of last timestep: [B, hidden_size]
        out = self.fc(out)  # [B, 1]
        return out.squeeze(-1)  # [B]


# Update existing instances to use the new class
m = GRUBaseline(4, 4)
model_dict["GRUBaseline"] = m

# quick sanity check
print(m)
print(f"GRUBaseline now has {sum(p.numel() for p in m.parameters())} parameters")

GRUBaseline(
  (gru): GRU(4, 4, batch_first=True)
  (fc): Linear(in_features=4, out_features=1, bias=True)
)
GRUBaseline now has 125 parameters


In [27]:
from torch_geometric_temporal.nn.recurrent import GConvGRU


class GCNGRUBaseline(nn.Module):
    def __init__(self, input_size, hidden_size, K=1):
        super().__init__()
        self.gconvgru = GConvGRU(in_channels=input_size, out_channels=hidden_size, K=K)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, x, edge_index, edge_weight, h=None):
        # x: [window_size, N, F]
        window_size, N, F = x.shape
        for t in range(window_size):
            h = self.gconvgru(x[t, :, :], edge_index[t, :, :], edge_weight[t, :], h)
        out = self.linear(h)
        return out.squeeze(-1), h


class GATGCNGRU(nn.Module):
    def __init__(self, input_size, hidden_size, K=1):
        super().__init__()
        self.gat = torch_geometric.nn.conv.GATv2Conv(
            in_channels=input_size,
            out_channels=input_size,
        )
        self.gconvgru = GConvGRU(in_channels=input_size, out_channels=hidden_size, K=K)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, x, edge_index, edge_weight, h=None):
        # x: [window_size, N, F]
        window_size, N, F = x.shape
        for t in range(window_size):
            (_, (e_index, attention_weights)) = self.gat(
                x=x[t, :, :],
                edge_index=edge_index[t, :, :],
                return_attention_weights=True,
            )
            h = self.gconvgru(x[t, :, :], e_index, attention_weights.squeeze(-1), h)
        out = self.linear(h)
        return out.squeeze(-1), h


# Replace the existing LSTM-based models in the notebook's model_dict
model_dict["GCNLSTMBaseline"] = GCNGRUBaseline(4, 4)
model_dict["GATGCNLSTM"] = GATGCNGRU(4, 4)

# quick sanity checks
for name in ("GCNLSTMBaseline", "GATGCNLSTM"):
    m = model_dict[name]
    print(f"{name}: {m}")
    print(f"{name} has {sum(p.numel() for p in m.parameters())} parameters")

GCNLSTMBaseline: GCNGRUBaseline(
  (gconvgru): GConvGRU(
    (conv_x_z): ChebConv(4, 4, K=1, normalization=sym)
    (conv_h_z): ChebConv(4, 4, K=1, normalization=sym)
    (conv_x_r): ChebConv(4, 4, K=1, normalization=sym)
    (conv_h_r): ChebConv(4, 4, K=1, normalization=sym)
    (conv_x_h): ChebConv(4, 4, K=1, normalization=sym)
    (conv_h_h): ChebConv(4, 4, K=1, normalization=sym)
  )
  (linear): Linear(in_features=4, out_features=1, bias=True)
)
GCNLSTMBaseline has 125 parameters
GATGCNLSTM: GATGCNGRU(
  (gat): GATv2Conv(4, 4, heads=1)
  (gconvgru): GConvGRU(
    (conv_x_z): ChebConv(4, 4, K=1, normalization=sym)
    (conv_h_z): ChebConv(4, 4, K=1, normalization=sym)
    (conv_x_r): ChebConv(4, 4, K=1, normalization=sym)
    (conv_h_r): ChebConv(4, 4, K=1, normalization=sym)
    (conv_x_h): ChebConv(4, 4, K=1, normalization=sym)
    (conv_h_h): ChebConv(4, 4, K=1, normalization=sym)
  )
  (linear): Linear(in_features=4, out_features=1, bias=True)
)
GATGCNLSTM has 173 parameters


In [29]:
# instantiate all 6 model classes and print per-parameter breakdown
all_models = {
    "LSTMBaseline": LSTMBaseline(4, 4),
    "GRUBaseline": GRUBaseline(4, 4),
    "GCNLSTMBaseline": GCNLSTMBaseline(4, 4),
    "GCNGRUBaseline": GCNGRUBaseline(4, 4),
    "GATGCNLSTM": GATGCNLSTM(4, 4),
    "GATGCNGRU": GATGCNGRU(4, 4),
}

for name, m in all_models.items():
    total = sum(p.numel() for p in m.parameters())
    print(f"\n{name} total parameters: {total}")
    for pname, p in m.named_parameters():
        print(f"  {pname}: {p.numel()}")


LSTMBaseline total parameters: 165
  lstm.weight_ih_l0: 64
  lstm.weight_hh_l0: 64
  lstm.bias_ih_l0: 16
  lstm.bias_hh_l0: 16
  fc.weight: 4
  fc.bias: 1

GRUBaseline total parameters: 125
  gru.weight_ih_l0: 48
  gru.weight_hh_l0: 48
  gru.bias_ih_l0: 12
  gru.bias_hh_l0: 12
  fc.weight: 4
  fc.bias: 1

GCNLSTMBaseline total parameters: 193
  gcnlstm.w_c_i: 4
  gcnlstm.b_i: 4
  gcnlstm.w_c_f: 4
  gcnlstm.b_f: 4
  gcnlstm.b_c: 4
  gcnlstm.w_c_o: 4
  gcnlstm.b_o: 4
  gcnlstm.conv_x_i.bias: 4
  gcnlstm.conv_x_i.lins.0.weight: 16
  gcnlstm.conv_h_i.bias: 4
  gcnlstm.conv_h_i.lins.0.weight: 16
  gcnlstm.conv_x_f.bias: 4
  gcnlstm.conv_x_f.lins.0.weight: 16
  gcnlstm.conv_h_f.bias: 4
  gcnlstm.conv_h_f.lins.0.weight: 16
  gcnlstm.conv_x_c.bias: 4
  gcnlstm.conv_x_c.lins.0.weight: 16
  gcnlstm.conv_h_c.bias: 4
  gcnlstm.conv_h_c.lins.0.weight: 16
  gcnlstm.conv_x_o.bias: 4
  gcnlstm.conv_x_o.lins.0.weight: 16
  gcnlstm.conv_h_o.bias: 4
  gcnlstm.conv_h_o.lins.0.weight: 16
  linear.weight: 

In [None]:
from dataclasses import replace

base = config.TrainingConfig()
tc = replace(
    base,
    lr=0.0055,
    batch_size=8,
    window_size=3,
    hidden_size=32,
    K=1,
)

In [None]:
from dataclasses import asdict

import mlflow

from sales_forecasting.utils.experiments import start_run

run_name = "testing_with_plant_edges"
with start_run(run_name=run_name):
    mlflow.log_params(asdict(tc))
    train_module.run_experiment(tc)

Epoch 1/100 | Train Loss: 1.1464 | Val Loss: 1.3608 | Val RMSE: 1.1665
 --> Best model saved at epoch 1
Epoch 2/100 | Train Loss: 1.1183 | Val Loss: 1.3470 | Val RMSE: 1.1606
 --> Best model saved at epoch 2
Epoch 3/100 | Train Loss: 1.0922 | Val Loss: 1.3314 | Val RMSE: 1.1538
 --> Best model saved at epoch 3
Epoch 4/100 | Train Loss: 1.0566 | Val Loss: 1.3625 | Val RMSE: 1.1673
Epoch 5/100 | Train Loss: 1.0203 | Val Loss: 1.4432 | Val RMSE: 1.2013
Epoch 6/100 | Train Loss: 0.9863 | Val Loss: 1.5235 | Val RMSE: 1.2343
Epoch 7/100 | Train Loss: 0.9633 | Val Loss: 1.5984 | Val RMSE: 1.2643
Epoch 8/100 | Train Loss: 0.9168 | Val Loss: 1.5013 | Val RMSE: 1.2253
Epoch 9/100 | Train Loss: 0.8397 | Val Loss: 1.5316 | Val RMSE: 1.2376
Epoch 10/100 | Train Loss: 0.7936 | Val Loss: 1.5539 | Val RMSE: 1.2465
Epoch 11/100 | Train Loss: 0.7386 | Val Loss: 1.5407 | Val RMSE: 1.2413
Epoch 12/100 | Train Loss: 0.7004 | Val Loss: 1.5407 | Val RMSE: 1.2412
Epoch 13/100 | Train Loss: 0.6574 | Val Loss: 



Epoch 1/100 | Train Loss: 1.0961 | Val Loss: 0.8603 | Val RMSE: 0.9275
 --> Best model saved at epoch 1
Epoch 2/100 | Train Loss: 1.0545 | Val Loss: 0.8278 | Val RMSE: 0.9098
 --> Best model saved at epoch 2
Epoch 3/100 | Train Loss: 1.0245 | Val Loss: 0.8192 | Val RMSE: 0.9051
 --> Best model saved at epoch 3
Epoch 4/100 | Train Loss: 1.0054 | Val Loss: 0.8762 | Val RMSE: 0.9361
Epoch 5/100 | Train Loss: 0.9649 | Val Loss: 0.8140 | Val RMSE: 0.9022
 --> Best model saved at epoch 5
Epoch 6/100 | Train Loss: 0.9181 | Val Loss: 0.8638 | Val RMSE: 0.9294
Epoch 7/100 | Train Loss: 0.8581 | Val Loss: 0.7898 | Val RMSE: 0.8887
 --> Best model saved at epoch 7
Epoch 8/100 | Train Loss: 0.8150 | Val Loss: 0.8305 | Val RMSE: 0.9113
Epoch 9/100 | Train Loss: 0.7575 | Val Loss: 0.8229 | Val RMSE: 0.9071
Epoch 10/100 | Train Loss: 0.7210 | Val Loss: 0.8293 | Val RMSE: 0.9107
Epoch 11/100 | Train Loss: 0.6977 | Val Loss: 0.8817 | Val RMSE: 0.9390
Epoch 12/100 | Train Loss: 0.6452 | Val Loss: 0.9215



Epoch 17/100 | Train Loss: 0.5717 | Val Loss: 0.9347 | Val RMSE: 0.9668
Early stopping triggered.




Epoch 1/100 | Train Loss: 1.1255 | Val Loss: 0.9809 | Val RMSE: 0.9904
 --> Best model saved at epoch 1
Epoch 2/100 | Train Loss: 1.0853 | Val Loss: 1.0009 | Val RMSE: 1.0005
Epoch 3/100 | Train Loss: 1.0830 | Val Loss: 0.9996 | Val RMSE: 0.9998
Epoch 4/100 | Train Loss: 1.0545 | Val Loss: 0.9941 | Val RMSE: 0.9970
Epoch 5/100 | Train Loss: 1.0425 | Val Loss: 0.9786 | Val RMSE: 0.9892
 --> Best model saved at epoch 5
Epoch 6/100 | Train Loss: 1.0151 | Val Loss: 1.0034 | Val RMSE: 1.0017
Epoch 7/100 | Train Loss: 0.9766 | Val Loss: 0.9848 | Val RMSE: 0.9924
Epoch 8/100 | Train Loss: 0.9645 | Val Loss: 0.9652 | Val RMSE: 0.9825
 --> Best model saved at epoch 8
Epoch 9/100 | Train Loss: 0.8874 | Val Loss: 0.9807 | Val RMSE: 0.9903
Epoch 10/100 | Train Loss: 0.8342 | Val Loss: 1.0286 | Val RMSE: 1.0142
Epoch 11/100 | Train Loss: 0.7830 | Val Loss: 0.9941 | Val RMSE: 0.9971
Epoch 12/100 | Train Loss: 0.7636 | Val Loss: 1.0750 | Val RMSE: 1.0368
Epoch 13/100 | Train Loss: 0.7362 | Val Loss: 



Epoch 18/100 | Train Loss: 0.5875 | Val Loss: 1.0446 | Val RMSE: 1.0220
Early stopping triggered.




Epoch 1/100 | Train Loss: 1.1308 | Val Loss: 0.5662 | Val RMSE: 0.7525
 --> Best model saved at epoch 1
Epoch 2/100 | Train Loss: 1.1009 | Val Loss: 0.5486 | Val RMSE: 0.7407
 --> Best model saved at epoch 2
Epoch 3/100 | Train Loss: 1.0861 | Val Loss: 0.5357 | Val RMSE: 0.7319
 --> Best model saved at epoch 3
Epoch 4/100 | Train Loss: 1.0882 | Val Loss: 0.5565 | Val RMSE: 0.7460
Epoch 5/100 | Train Loss: 1.0864 | Val Loss: 0.5308 | Val RMSE: 0.7286
 --> Best model saved at epoch 5
Epoch 6/100 | Train Loss: 1.0578 | Val Loss: 0.5215 | Val RMSE: 0.7221
 --> Best model saved at epoch 6
Epoch 7/100 | Train Loss: 1.0334 | Val Loss: 0.4896 | Val RMSE: 0.6997
 --> Best model saved at epoch 7
Epoch 8/100 | Train Loss: 0.9906 | Val Loss: 0.4102 | Val RMSE: 0.6405
 --> Best model saved at epoch 8
Epoch 9/100 | Train Loss: 0.9666 | Val Loss: 0.3647 | Val RMSE: 0.6039
 --> Best model saved at epoch 9
Epoch 10/100 | Train Loss: 0.9030 | Val Loss: 0.3435 | Val RMSE: 0.5861
 --> Best model saved at 



Epoch 25/100 | Train Loss: 0.5915 | Val Loss: 0.4033 | Val RMSE: 0.6350
Early stopping triggered.




Epoch 1/100 | Train Loss: 1.1139 | Val Loss: 0.8225 | Val RMSE: 0.9069
 --> Best model saved at epoch 1
Epoch 2/100 | Train Loss: 1.0817 | Val Loss: 0.8244 | Val RMSE: 0.9080
Epoch 3/100 | Train Loss: 1.0784 | Val Loss: 0.8362 | Val RMSE: 0.9144
Epoch 4/100 | Train Loss: 1.0734 | Val Loss: 0.8153 | Val RMSE: 0.9029
 --> Best model saved at epoch 4
Epoch 5/100 | Train Loss: 1.0705 | Val Loss: 0.7979 | Val RMSE: 0.8932
 --> Best model saved at epoch 5
Epoch 6/100 | Train Loss: 1.0225 | Val Loss: 0.7507 | Val RMSE: 0.8664
 --> Best model saved at epoch 6
Epoch 7/100 | Train Loss: 0.9547 | Val Loss: 0.6879 | Val RMSE: 0.8294
 --> Best model saved at epoch 7
Epoch 8/100 | Train Loss: 0.9261 | Val Loss: 0.6855 | Val RMSE: 0.8279
 --> Best model saved at epoch 8
Epoch 9/100 | Train Loss: 0.8712 | Val Loss: 0.6846 | Val RMSE: 0.8274
 --> Best model saved at epoch 9
Epoch 10/100 | Train Loss: 0.8368 | Val Loss: 0.6698 | Val RMSE: 0.8184
 --> Best model saved at epoch 10
Epoch 11/100 | Train Los



Epoch 20/100 | Train Loss: 0.6005 | Val Loss: 0.7334 | Val RMSE: 0.8564
Early stopping triggered.




Mean best validation rmse: 0.8817 ± 0.2171
