In [35]:
### IMPORTS ###
import warnings
warnings.filterwarnings("ignore")
import copy
from pathlib import Path
import warnings

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
import numpy as np
import pandas as pd
import torch

from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import MAE, SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import (
    optimize_hyperparameters,
)
import pandas as pd
# Norwegian special days/holidays using the holidays package
import holidays

In [36]:
################## CLEANING THE PURCHASE ORDERS DATA ##############

orders = pd.read_csv("../data/kernel/purchase_orders.csv")

# Time is in GMT+2 which is Norway time
# Make delivery_date, created_date_time and modified_date_time to GMT +2
orders['delivery_date'] = pd.to_datetime(orders['delivery_date'], utc=True).dt.tz_convert('Etc/GMT-2')
orders['created_date_time'] = pd.to_datetime(orders['created_date_time'], utc=True).dt.tz_convert('Etc/GMT-2')
orders['modified_date_time'] = pd.to_datetime(orders['modified_date_time'], utc=True).dt.tz_convert('Etc/GMT-2')


################# CLEANING THE RECEIVALS DATA ########################
receivals = pd.read_csv("../data/kernel/receivals.csv")

# Make the date_arrival to GMT +2
receivals['date_arrival'] = pd.to_datetime(receivals['date_arrival'], utc=True).dt.tz_convert('Etc/GMT-2')


############### MERGE ORDERS AND RECEIVALS DATA ###########################
# --- Merge orders and receivals WITHOUT aggregation ---
orders_with_receivals = orders.merge(
    receivals,
    on=["purchase_order_id", "purchase_order_item_no"],
    how="left",
    suffixes=('_order', '_receival')
)

# --- Fill missing values for orders with no receivals ---
orders_with_receivals["net_weight"] = orders_with_receivals["net_weight"].fillna(0)
orders_with_receivals["date_arrival"] = pd.to_datetime(orders_with_receivals["date_arrival"])


# Make the orders with PUND in KGs, and change quantity accordingly
# 1 PUND = 0,45359237 kilogram
orders_with_receivals.loc[orders_with_receivals['unit'] == 'PUND', 'quantity'] = orders_with_receivals.loc[orders_with_receivals['unit'] == 'PUND', 'quantity'] * 0.45359237
orders_with_receivals.loc[orders_with_receivals['unit'] == 'PUND', 'net_weight'] = orders_with_receivals.loc[orders_with_receivals['unit'] == 'PUND', 'net_weight'] * 0.45359237
# Change the unit to KG too: orders_with_receivals.loc[orders_with_receivals['unit'] == 'PUND', 'unit'] = 'KG'
# Drop unit_id and unit columns
orders_with_receivals = orders_with_receivals.drop(columns=['unit_id', 'unit'])

# --- Derived features ---
orders_with_receivals["fill_fraction"] = orders_with_receivals["net_weight"] / orders_with_receivals["quantity"]
orders_with_receivals["lead_time"] = (
    orders_with_receivals["date_arrival"] - orders_with_receivals["delivery_date"]
).dt.days
orders_with_receivals["lead_time"] = orders_with_receivals["lead_time"].fillna(0)


####################### SELECT RELEVANT COLUMNS FROM THE MERGED DATAFRAME ##################################
orders_with_receivals = orders_with_receivals[orders_with_receivals['rm_id'].notnull() & orders_with_receivals['date_arrival'].notnull()]
# date_arrival = actual date of receival, delivery_date = expected date of receival
# lead_time = date_arrival - delivery_date
# quantity  = quantity, net_weight = weight in kg (the actual target per day etc)
selected = orders_with_receivals[["rm_id", "date_arrival", "net_weight", "supplier_id", "delivery_date", "product_id_receival", "quantity", "lead_time"]]
# Filter out the selected rows where rm_id is null or date_arrival is null
selected = selected[selected['rm_id'].notnull() & selected['date_arrival'].notnull()]



In [37]:
##################### CREATING TIME_IDX AND AGGREGATING TO DAILY LEVEL AND FILLING GAPS WITH 0 NET_WEIGHT RECEIVALS ############################
# make a copy and normalize date_arrival to date-only (drop time) so grouping is by year-month-day
df_agg = selected.copy()
# ensure date_arrival is a datetime and floor to day (sets time to 00:00:00)
df_agg['date_arrival'] = df_agg['date_arrival'].dt.floor('D')
# Remove timezone info if present
df_agg['date_arrival'] = df_agg['date_arrival'].dt.tz_localize(None)

df_agg = df_agg.groupby(['rm_id', 'date_arrival']).agg({
    'net_weight': 'sum',
    'quantity': 'sum',
}).reset_index()

# Add time_idx based on days since each rm_id's minimum date
df_agg = df_agg.sort_values(['rm_id', 'date_arrival'])
df_agg['local_time_idx'] = (df_agg['date_arrival'] - df_agg.groupby('rm_id')['date_arrival'].transform('min')).dt.days

# Fill gaps from each rm_id's min date to 2024-12-31 with 0 net_weight entries
end_date = pd.Timestamp('2024-12-31')
all_filled = []

for rm_id, group in df_agg.groupby('rm_id'):
    min_date = group['date_arrival'].min()
    max_idx = (end_date - min_date).days
    
    full_range = pd.DataFrame({
        'local_time_idx': range(0, max_idx + 1)
    })
    full_range['rm_id'] = rm_id
    full_range['date_arrival'] = min_date + pd.to_timedelta(full_range['local_time_idx'], unit='D')
    
    merged = pd.merge(full_range, group, on=['rm_id', 'local_time_idx', 'date_arrival'], how='left')
    merged['net_weight'] = merged['net_weight'].fillna(0)
    merged['quantity'] = merged['quantity'].fillna(0)
  
    all_filled.append(merged)

df_agg = pd.concat(all_filled, ignore_index=True)
selected_with_local_time = df_agg

In [38]:
### CUMULATIVE TESTING#################

# For rm_ids that start in 2024 set their local_time_idx to start from 0 and fill 2024-01-01 to their start date with 0 net_weight
# For rm_ids that start in 2024, backfill from 2024-01-01 and renumber local_time_idx from 0
rm_ids_to_fix = []
for rm_id, group in selected_with_local_time.groupby('rm_id'):
    min_date = group['date_arrival'].min()
    if min_date.year == 2024:
        rm_ids_to_fix.append((rm_id, min_date))

# Process each rm_id that needs fixing
for rm_id, original_min_date in rm_ids_to_fix:
    # Calculate days to add from 2024-01-01 to original_min_date
    days_to_add = (original_min_date - pd.Timestamp('2024-01-01')).days
    
    # Create additional rows from 2024-01-01 to day before original_min_date
    if days_to_add > 0:
        additional_rows = pd.DataFrame({
            'rm_id': rm_id,
            'date_arrival': [pd.Timestamp('2024-01-01') + pd.Timedelta(days=i) for i in range(days_to_add)],
            'net_weight': 0.0,
            'quantity': 0.0,
            'local_time_idx': range(days_to_add)  # 0, 1, 2, ... up to days_to_add-1
        })
        
        # Update local_time_idx for existing rows of this rm_id
        mask = selected_with_local_time['rm_id'] == rm_id
        selected_with_local_time.loc[mask, 'local_time_idx'] += days_to_add
        
        # Append the new rows
        selected_with_local_time = pd.concat([selected_with_local_time, additional_rows], ignore_index=True)

# Sort by rm_id and date_arrival to maintain order
selected_with_local_time = selected_with_local_time.sort_values(['rm_id', 'date_arrival']).reset_index(drop=True)


# For every year and rm_id I want to make net_weight cumulative. So day1 = day1, day2 = day1+day2, day3 = day1+day2+day3 etc.
# But it should reset at the start of each year for each rm_id
selected_with_local_time['year'] = selected_with_local_time['date_arrival'].dt.year
selected_with_local_time['net_weight_cum'] = selected_with_local_time.groupby(['rm_id', 'year'])['net_weight'].cumsum()

In [39]:
######################### ADD ADDITIONAL FEATURES ##################################
# Add additional features
selected_with_local_time["month"] = selected_with_local_time["date_arrival"].dt.month.astype(str).astype("category")
selected_with_local_time["year"] = selected_with_local_time["date_arrival"].dt.year.astype(str).astype("category")
selected_with_local_time["day_of_week"] = selected_with_local_time["date_arrival"].dt.dayofweek.astype(str).astype("category")  # 0=Monday, 6=Sunday
selected_with_local_time["day_of_month"] = selected_with_local_time["date_arrival"].dt.day.astype(str).astype("category")
selected_with_local_time["week_of_year"] = selected_with_local_time["date_arrival"].dt.isocalendar().week.astype(str).astype("category")
selected_with_local_time["quarter"] = selected_with_local_time["date_arrival"].dt.quarter.astype(str).astype("category")
selected_with_local_time["is_weekend"] = (selected_with_local_time["date_arrival"].dt.dayofweek >= 5).astype(int)
selected_with_local_time["is_month_start"] = selected_with_local_time["date_arrival"].dt.is_month_start.astype(int)
selected_with_local_time["is_month_end"] = selected_with_local_time["date_arrival"].dt.is_month_end.astype(int)
selected_with_local_time["log_weight"] = np.log1p(selected_with_local_time["net_weight"])

print("Adding lag and rolling features...")
# Lag features - CRITICAL for time series forecasting
for lag in [7, 14, 28, 56, 91]:
    selected_with_local_time[f'net_weight_lag_{lag}'] = selected_with_local_time.groupby('rm_id')['net_weight'].shift(lag)
    selected_with_local_time[f'log_weight_lag_{lag}'] = selected_with_local_time.groupby('rm_id')['log_weight'].shift(lag)

# Rolling window statistics - capture trends
for window in [7, 28, 91]:
    selected_with_local_time[f'net_weight_rolling_mean_{window}'] = selected_with_local_time.groupby('rm_id')['net_weight'].transform(
        lambda x: x.shift(1).rolling(window=window, min_periods=1).mean()
    )
    selected_with_local_time[f'net_weight_rolling_std_{window}'] = selected_with_local_time.groupby('rm_id')['net_weight'].transform(
        lambda x: x.shift(1).rolling(window=window, min_periods=1).std()
    )

# Fill NaN values from lag/rolling features with 0
lag_rolling_cols = [col for col in selected_with_local_time.columns if 'lag_' in col or 'rolling_' in col]
selected_with_local_time[lag_rolling_cols] = selected_with_local_time[lag_rolling_cols].fillna(0)

print(f"Added {len(lag_rolling_cols)} lag and rolling features")

# Norwegian special days/holidays
# Fixed holidays
def get_norwegian_holidays(year):
    """Return dictionary of Norwegian holidays for a given year"""
    from datetime import timedelta
    
    holidays = {}
    
    # Fixed date holidays
    holidays[f'{year}-01-01'] = 'New Year'
    holidays[f'{year}-05-01'] = 'Labour Day'
    holidays[f'{year}-05-17'] = 'Constitution Day'
    holidays[f'{year}-12-24'] = 'Christmas Eve'
    holidays[f'{year}-12-25'] = 'Christmas Day'
    holidays[f'{year}-12-26'] = 'Boxing Day'
    holidays[f'{year}-12-31'] = 'New Year Eve'
    
    # Easter-based holidays (Easter dates vary each year)
    # Approximate Easter calculation (Meeus/Jones/Butcher algorithm)
    a = year % 19
    b = year // 100
    c = year % 100
    d = b // 4
    e = b % 4
    f = (b + 8) // 25
    g = (b - f + 1) // 3
    h = (19 * a + b - d - g + 15) % 30
    i = c // 4
    k = c % 4
    l = (32 + 2 * e + 2 * i - h - k) % 7
    m = (a + 11 * h + 22 * l) // 451
    month = (h + l - 7 * m + 114) // 31
    day = ((h + l - 7 * m + 114) % 31) + 1
    
    easter = pd.Timestamp(year=year, month=month, day=day)
    
    # Easter-related holidays
    holidays[(easter - timedelta(days=3)).strftime('%Y-%m-%d')] = 'Maundy Thursday'
    holidays[(easter - timedelta(days=2)).strftime('%Y-%m-%d')] = 'Good Friday'
    holidays[easter.strftime('%Y-%m-%d')] = 'Easter Sunday'
    holidays[(easter + timedelta(days=1)).strftime('%Y-%m-%d')] = 'Easter Monday'
    holidays[(easter + timedelta(days=39)).strftime('%Y-%m-%d')] = 'Ascension Day'
    holidays[(easter + timedelta(days=49)).strftime('%Y-%m-%d')] = 'Whit Sunday'
    holidays[(easter + timedelta(days=50)).strftime('%Y-%m-%d')] = 'Whit Monday'
    
    return holidays

# Create a mapping of all dates to holidays
all_holidays = {}
for year in range(selected_with_local_time['date_arrival'].dt.year.min(), 
                  selected_with_local_time['date_arrival'].dt.year.max() + 1):
    all_holidays.update(get_norwegian_holidays(year))

# Add special day column
selected_with_local_time['date_str'] = selected_with_local_time['date_arrival'].dt.strftime('%Y-%m-%d')
selected_with_local_time['special_days'] = selected_with_local_time['date_str'].map(all_holidays).fillna('none').astype('category')
selected_with_local_time.drop('date_str', axis=1, inplace=True)

# Add binary flag for whether it's a holiday
selected_with_local_time['is_holiday'] = (selected_with_local_time['special_days'] != 'none').astype(int)

special_days = list(all_holidays.values())

# Make rm_id a string instead of numeric
selected_with_local_time["rm_id"] = selected_with_local_time["rm_id"].astype(int).astype(str).astype("category")
selected_with_local_time["is_holiday"] = selected_with_local_time["is_holiday"].astype(str).astype("category")
selected_with_local_time.drop("year", axis=1, inplace=True)

Adding lag and rolling features...
Added 16 lag and rolling features
Added 16 lag and rolling features


In [41]:
######################### CREATE TIME SERIES DATASET FOR PYTORCH FORECASTING ##################################
full_data = selected_with_local_time.copy()

max_prediction_length = 151
max_encoder_length = 365
# V: training_cutoff = data["time_idx"].max() - max_prediction_length

training = TimeSeriesDataSet(
    # V: data[lambda x: x.local_time_idx <= training_cutoff],
    data = full_data,
    time_idx="local_time_idx",
    target="net_weight_cum",
    group_ids=["rm_id"],
    min_encoder_length=max_encoder_length
    // 2,  # keep encoder length long (as it is in the validation set)
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["rm_id"],
    #static_reals= no static real yet,
    time_varying_known_categoricals=["special_days", "month", "day_of_week", "is_holiday", "day_of_month", "week_of_year", "quarter"],
    #variable_groups={
    #    "special_days": special_days
    #},  # group of categorical variables can be treated as one variable
    time_varying_known_reals=["local_time_idx", "is_weekend", "is_month_start", "is_month_end"],
    # CAN PUT YEAR IN TIME_VARYING_KNOWN_REALS
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=[
        "quantity",
        "net_weight",
        "log_weight",
        "net_weight_cum",
        # Lag features (10 total)
        "net_weight_lag_7", "net_weight_lag_14", "net_weight_lag_28", "net_weight_lag_56", "net_weight_lag_91",
        "log_weight_lag_7", "log_weight_lag_14", "log_weight_lag_28", "log_weight_lag_56", "log_weight_lag_91",
        # Rolling features (6 total)
        "net_weight_rolling_mean_7", "net_weight_rolling_mean_28", "net_weight_rolling_mean_91",
        "net_weight_rolling_std_7", "net_weight_rolling_std_28", "net_weight_rolling_std_91",
    ],
    target_normalizer=GroupNormalizer(
        groups=["rm_id"], transformation="softplus"
    ),  # use softplus and normalize by group
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

# create validation set (predict=True) which means to predict the last max_prediction_length points in time
# for each series
#V: validation = TimeSeriesDataSet.from_dataset(
#V:    training, data, predict=True, stop_randomization=True
#V:)

# create dataloaders for model
batch_size = 128  # set this between 32 to 128
train_dataloader = training.to_dataloader(
    train=True, batch_size=batch_size, num_workers=0
)
#V: val_dataloader = validation.to_dataloader(
#V:    train=False, batch_size=batch_size * 10, num_workers=0

In [42]:
################# DECIDING ON THE MODEL AND TRAINER PARAMETERS ##########################
lr_logger = LearningRateMonitor()  # log the learning rate
logger = TensorBoardLogger("lightning_logs")  # logging results to a tensorboard

# {'gradient_clip_val': 0.012179320703577733, 'hidden_size': 29, 'dropout': 0.12469376786070158, 'hidden_continuous_size': 14, 'attention_head_size': 1, 'learning_rate': 0.004342145465626561}. Best is trial 2 with value: 449.2545166015625

# configure network and trainer
pl.seed_everything(42)
trainer = pl.Trainer(
    max_epochs=50,
    accelerator="gpu",
    enable_model_summary=True,
    gradient_clip_val=0.012179320703577733,
    limit_train_batches=30,
    #fast_dev_run = True,
    callbacks=[lr_logger],
    logger=logger,
)


tft = TemporalFusionTransformer.from_dataset(
    training,
    # not meaningful for finding the learning rate but otherwise very important
    learning_rate=0.004342145465626561,
    hidden_size=29,  # most important hyperparameter apart from learning rate
    # number of attention heads. Set to up to 4 for large datasets
    attention_head_size=1,
    dropout=0.12469376786070158,  # between 0.1 and 0.3 are good values
    hidden_continuous_size=14,  # set to <= hidden_size
    loss=QuantileLoss(),
    #optimizer="ranger", OPTIMIZER FOR FINDING BEST LEARNING RATE
    # reduce learning rate if no improvement in validation loss after x epochs
    reduce_on_plateau_patience=4,
)
print(f"Number of parameters in network: {tft.size() / 1e3:.1f}k")

Seed set to 42


ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Number of parameters in network: 119.8k


In [43]:
######## TRAINING THE MODEL ##########

trainer.fit(
    tft,
    train_dataloaders=train_dataloader
)

You are using a CUDA device ('NVIDIA GeForce RTX 3070') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                  | 7.2 K  | train
3  | prescalers                         | ModuleDict                      | 784    | train
4  | static_variable_selection          | VariableSelectionNetwork        | 4.8 K  | train
5  | encoder_variable_selection         | VariableSelectionNetwork        | 52.5 K | train
6  | decoder_variable_selection         | VariableSelectionNetwork        | 10.2 K | train
7  | static_context_variable_selection  | GatedResidualNetwork            | 3.5 K  | train
8  | static_context_initial_hidden_lstm |

Epoch 49: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 30/30 [00:47<00:00,  0.64it/s, v_num=22, train_loss_step=1.36e+4, train_loss_epoch=2.83e+4]

`Trainer.fit` stopped: `max_epochs=50` reached.


Epoch 49: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 30/30 [00:47<00:00,  0.63it/s, v_num=22, train_loss_step=1.36e+4, train_loss_epoch=2.83e+4]



In [44]:
################# FULLL PREDICTION FOR ALL RM_IDs (NEED ENCODER AND DECODER DATA) ###############


rm_ids = full_data['rm_id'].unique().tolist()
predict_data = []
# Create prediction date range
pred_start = pd.Timestamp('2025-01-01')
pred_end = pd.Timestamp('2025-05-31')
pred_dates = pd.date_range(start=pred_start, end=pred_end, freq='D')

all_predict_dfs = []

for rm_id in rm_ids:
    test_rm_id = rm_id  # must match categorical rm_id type
    historical = full_data[full_data['rm_id'] == test_rm_id].copy()
    if historical.empty:
        continue
    min_date = historical['date_arrival'].min()

    # build prediction rows for this rm_id
    rows = []
    for date in pred_dates:
        time_idx = (date - min_date).days
        date_str = date.strftime('%Y-%m-%d')
        special_day = all_holidays.get(date_str, 'none')
        is_holiday = '1' if special_day != 'none' else '0'

        rows.append({
            'rm_id': test_rm_id,
            'date_arrival': date,
            'local_time_idx': time_idx,
            'month': str(date.month),
            'day_of_week': str(date.dayofweek),
            'day_of_month': str(date.day),
            'week_of_year': str(date.isocalendar().week),
            'quarter': str(date.quarter),
            'is_weekend': int(date.dayofweek >= 5),
            'is_month_start': int(date.is_month_start),
            'is_month_end': int(date.is_month_end),
            'special_days': special_day,
            'is_holiday': is_holiday,
            'net_weight_cum': 0,  # placeholder
            'net_weight': 0,   # placeholder
            'quantity': 0,     # placeholder
            'log_weight': 0,   # placeholder
            # Lag and rolling features will be computed from historical data
            'net_weight_lag_7': 0,
            'net_weight_lag_14': 0,
            'net_weight_lag_28': 0,
            'net_weight_lag_56': 0,
            'net_weight_lag_91': 0,
            'log_weight_lag_7': 0,
            'log_weight_lag_14': 0,
            'log_weight_lag_28': 0,
            'log_weight_lag_56': 0,
            'log_weight_lag_91': 0,
            'net_weight_rolling_mean_7': 0,
            'net_weight_rolling_mean_28': 0,
            'net_weight_rolling_mean_91': 0,
            'net_weight_rolling_std_7': 0,
            'net_weight_rolling_std_28': 0,
            'net_weight_rolling_std_91': 0,
        })

    pred_df_rm = pd.DataFrame(rows)
    
    # Note: Lag and rolling features are already set to 0 in the rows above
    # They are unknown for future dates and will remain as placeholders
    
    # encoder/context data (last max_encoder_length days)
    encoder_data = historical.tail(max_encoder_length).copy()
    encoder_data['local_time_idx'] = encoder_data['local_time_idx'].astype(int)

    # combine encoder + prediction for this rm_id and collect
    combined = pd.concat([encoder_data, pred_df_rm], ignore_index=True)
    all_predict_dfs.append(combined)

# final combined prediction dataframe for all rm_ids
predict_data = pd.concat(all_predict_dfs, ignore_index=True)

predict_data = pd.DataFrame(predict_data)

# Convert to categorical to match training data
predict_data['rm_id'] = predict_data['rm_id'].astype(str).astype('category')
predict_data['month'] = predict_data['month'].astype(str).astype('category')
predict_data['day_of_week'] = predict_data['day_of_week'].astype(str).astype('category')
predict_data['day_of_month'] = predict_data['day_of_month'].astype(str).astype('category')
predict_data['week_of_year'] = predict_data['week_of_year'].astype(str).astype('category')
predict_data['quarter'] = predict_data['quarter'].astype(str).astype('category')
predict_data['special_days'] = predict_data['special_days'].astype(str).astype('category')
predict_data['is_holiday'] = predict_data['is_holiday'].astype(str).astype('category')

# Ensure local_time_idx is integer (required by TimeSeriesDataSet)
predict_data['local_time_idx'] = predict_data['local_time_idx'].astype(int)

print(f"Prediction dataframe shape: {predict_data.shape}")
print(f"Date range: {predict_data['date_arrival'].min()} to {predict_data['date_arrival'].max()}")
print(f"Time index range: {predict_data['local_time_idx'].min()} to {predict_data['local_time_idx'].max()}")
print("\nFirst few rows:")
print(predict_data.head())

Prediction dataframe shape: (104748, 33)
Date range: 2024-01-02 00:00:00 to 2025-05-31 00:00:00
Time index range: 1 to 7655

First few rows:
   local_time_idx rm_id date_arrival  net_weight  quantity  net_weight_cum  \
0            7132   342   2024-01-02         0.0       0.0             0.0   
1            7133   342   2024-01-03         0.0       0.0             0.0   
2            7134   342   2024-01-04         0.0       0.0             0.0   
3            7135   342   2024-01-05         0.0       0.0             0.0   
4            7136   342   2024-01-06         0.0       0.0             0.0   

  month day_of_week day_of_month week_of_year  ... net_weight_lag_91  \
0     1           1            2            1  ...               0.0   
1     1           2            3            1  ...               0.0   
2     1           3            4            1  ...               0.0   
3     1           4            5            1  ...               0.0   
4     1           5           

In [52]:
################## MAKE PREDICTIONS #######################

predictions = tft.predict(predict_data, mode="raw", return_x=True, return_index=True, return_decoder_lengths=True)

ltdx_and_rmid = predictions.index

output = predictions.output.prediction[:,:,0]

pred = []

pred_start = pd.Timestamp('2025-01-01')
pred_end = pd.Timestamp('2025-05-31')
pred_dates = pd.date_range(start=pred_start, end=pred_end, freq='D')


for rm_id_index in range(0,203):
    rm_id_test = ltdx_and_rmid["rm_id"][rm_id_index]
    ltdx_test = ltdx_and_rmid["local_time_idx"][rm_id_index]
    for date in pred_dates:
        pred_weight = output[rm_id_index][(date-pred_start).days].item()
        pred.append({
            "rm_id": rm_id_test,
            "local_time_idx": ltdx_test,
            "date": date,
            "predicted_weight": pred_weight
        })


pred = pd.DataFrame(pred)
pred_over_0 = pred[pred["predicted_weight"]>0]

ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [53]:
################## PREPARING THE SUBMISSION FILE #######################
sample_submission = pd.read_csv("../data/sample_submission.csv")
prediction_mapping = pd.read_csv("../data/prediction_mapping.csv", parse_dates=["forecast_start_date", "forecast_end_date"])

submission = sample_submission.merge(prediction_mapping, on="ID")
submission["forecast_end_date"] = pd.to_datetime(submission["forecast_end_date"])
submission["forecast_start_date"] = pd.to_datetime(submission["forecast_start_date"])

for p in pred_over_0.itertuples():
    rm_id = p.rm_id
    date_arrival = p.date.replace(tzinfo=None)
    predicted_weight = p.predicted_weight
    submission.loc[
        (submission['rm_id'] == int(rm_id)) & (submission['forecast_end_date'] >= date_arrival),
        'predicted_weight'
    ] = predicted_weight

In [57]:
test_submission = submission[["ID", "predicted_weight"]]
test_submission.to_csv("submission_tft_cumulative_lag_rolling.csv", index=False)

In [None]:
############## PRINT THE MODEL's TOTAL PREDICTIONS PER RM_ID ##############
filtered = submission.copy()

agg_df = filtered.groupby("rm_id", as_index=False).agg({
    "predicted_weight": "max",
}).sort_values("predicted_weight", ascending=False)


print("TOTAL PREDICTED WEIGHTS PER RM_ID WITH THE ML MODEL:")
print(agg_df[agg_df["predicted_weight"]>0])

TOTAL PREDICTED WEIGHTS PER RM_ID WITH THE ML MODEL:
     rm_id  predicted_weight
176   3781      4.587214e+06
180   3865      3.330814e+06
151   3126      2.402842e+06
150   3125      1.392744e+06
147   3122      1.358385e+06
160   3282      1.148517e+06
149   3124      9.164329e+05
148   3123      7.534633e+05
182   3901      3.643884e+05
85    2142      2.133031e+05
79    2134      1.293509e+05
87    2144      1.251198e+05
80    2135      1.136899e+05
159   3265      8.336219e+04
181   3883      7.053468e+04
172   3642      6.301469e+04
163   3421      4.528061e+04
185   4021      4.474627e+04
74    2129      4.171755e+04
191   4263      3.826533e+04
76    2131      3.327738e+04
190   4222      3.219623e+04
162   3381      2.636178e+04
187   4081      2.336391e+04
171   3621      2.255507e+04
197   4443      1.979580e+04
186   4044      1.314957e+04
169   3581      1.161203e+04
192   4302      9.161141e+03
195   4401      7.290643e+03
71    2125      5.578164e+03
193   4343      4.2