# Fine-tune TFT on US MMS Data
This notebook fine-tunes a pretrained Temporal Fusion Transformer using the US MMS dataset.

In [1]:
!pip install wandb

Collecting wandb
  Using cached wandb-0.20.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting protobuf!=4.21.0,!=5.28.0,<7,>=3.19.0 (from wandb)
  Using cached protobuf-6.31.1-cp39-abi3-manylinux2014_x86_64.whl.metadata (593 bytes)
Collecting sentry-sdk>=2.0.0 (from wandb)
  Using cached sentry_sdk-2.29.1-py2.py3-none-any.whl.metadata (10 kB)
Collecting setproctitle (from wandb)
  Using cached setproctitle-1.3.6-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Using cached wandb-0.20.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (23.2 MB)
Using cached protobuf-6.31.1-cp39-abi3-manylinux2014_x86_64.whl (321 kB)
Using cached sentry_sdk-2.29.1-py2.py3-none-any.whl (341 kB)
Using cached setproctitle-1.3.6-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (31 kB)
Installing collected packages: setproctitle, sentry-sdk, protobuf, wan

In [2]:
!pip install tft-torch

Collecting tft-torch
  Using cached tft_torch-0.0.6-py3-none-any.whl.metadata (5.8 kB)
Using cached tft_torch-0.0.6-py3-none-any.whl (21 kB)
Installing collected packages: tft-torch
Successfully installed tft-torch-0.0.6


In [3]:
import argparse
import pickle
from typing import Dict, List, Tuple
from functools import partial
import copy
import numpy as np
from omegaconf import OmegaConf, DictConfig
import pandas as pd
from tqdm import tqdm
import torch
from torch import optim
from torch import nn
import torch.nn.init as init
from torch.utils.data import Dataset, DataLoader, Subset
from tft_torch.tft import TemporalFusionTransformer
import tft_torch.loss as tft_loss
import json
import os 
import wandb
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [4]:
filename= "/home/jovyan/phenology-ml-clm/data/sorted_BDT_50_20_merged_1982_2021_US_MMS.pkl"
checkpoint= "/home/jovyan/phenology-ml-clm/docs/weights_merged_BDT_1982_2021_feb2025_checkpoint.pth"
output_path= "/home/jovyan/phenology-ml-clm/data/US_MMS_finetuned_from_feb_chkpt_060725.pth"

In [19]:
configuration = {'optimization':
                 {
                     'batch_size': {'training': 8, 'inference': 8},# both weere 64 before
                     'learning_rate': 1e-4,#was 0.001
                     'max_grad_norm': 1.0,
                 }
                 ,
                 'model':
                 {
                     'dropout': 0.2,#was 0.05 before
                     'state_size': 160,
                     'output_quantiles': [0.1, 0.5, 0.9],
                     'lstm_layers': 4,#was 2
                     'attention_heads': 4 #was 4 #then 6
                 },
                 # these arguments are related to possible extensions of the model class
                 'task_type':'regression',
                 'target_window_start': None, 
                 'training_data':  filename,
                'checkpoint': checkpoint, 
                'output_path':output_path, }

wandb.init(project="TL_US_MMS", config = configuration)

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mayalahlou[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
with open(filename,'rb') as fp:
        data = pickle.load(fp)
    
feature_map = data['feature_map']
cardinalities_map = data['categorical_cardinalities']

structure = {
        'num_historical_numeric': len(feature_map['historical_ts_numeric']),
        'num_historical_categorical': len(feature_map['historical_ts_categorical']),
        'num_static_numeric': len(feature_map['static_feats_numeric']),
        'num_static_categorical': len(feature_map['static_feats_categorical']),
        'num_future_numeric': len(feature_map['future_ts_numeric']),
        'num_future_categorical': len(feature_map['future_ts_categorical']),
        'historical_categorical_cardinalities': [cardinalities_map[feat] + 1 for feat in feature_map['historical_ts_categorical']],
        'static_categorical_cardinalities': [cardinalities_map[feat] + 1 for feat in feature_map['static_feats_categorical']],
        'future_categorical_cardinalities': [cardinalities_map[feat] + 1 for feat in feature_map['future_ts_categorical']],
    }

configuration['data_props'] = structure



# random

In [12]:




def weight_init(m):
    """
    Usage:
        model = Model()
        model.apply(weight_init)
    """
    if isinstance(m, nn.Conv1d):
        init.normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.Conv2d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.Conv3d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose1d):
        init.normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose2d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.ConvTranspose3d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.BatchNorm1d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm3d):
        init.normal_(m.weight.data, mean=1, std=0.02)
        init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.LSTM):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
    elif isinstance(m, nn.LSTMCell):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
    elif isinstance(m, nn.GRU):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
        for names in m._all_weights:
            for name in filter(lambda n: "bias" in n, names):
                bias = getattr(m, name)
                n = bias.size(0)
                bias.data[:n // 3].fill_(-1.)
    elif isinstance(m, nn.GRUCell):
        for param in m.parameters():
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
class DictDataSet(Dataset):
    def __init__(self, array_dict: Dict[str, np.ndarray]):
        self.keys_list = []
        for k, v in array_dict.items():
            self.keys_list.append(k)
            if np.issubdtype(v.dtype, np.dtype('bool')):
                setattr(self, k, torch.ByteTensor(v))
            elif np.issubdtype(v.dtype, np.int8):
                setattr(self, k, torch.CharTensor(v))
            elif np.issubdtype(v.dtype, np.int16):
                setattr(self, k, torch.ShortTensor(v))
            elif np.issubdtype(v.dtype, np.int32):
                setattr(self, k, torch.IntTensor(v))
            elif np.issubdtype(v.dtype, np.int64):
                setattr(self, k, torch.LongTensor(v))
            elif np.issubdtype(v.dtype, np.float32):
                setattr(self, k, torch.FloatTensor(v))
            elif np.issubdtype(v.dtype, np.float64):
                setattr(self, k, torch.DoubleTensor(v))
            else:
                setattr(self, k, torch.FloatTensor(v))

    def __getitem__(self, index):
        return {k: getattr(self, k)[index] for k in self.keys_list}

    def __len__(self):
        return getattr(self, self.keys_list[0]).shape[0]
                
                
def recycle(iterable):
    while True:
        for x in iterable:
            yield x

def get_set_and_loaders(data_dict: Dict[str, np.ndarray],
                        shuffled_loader_config: Dict,
                        serial_loader_config: Dict,
                        ignore_keys: List[str] = None,
                        ) -> Tuple[torch.utils.data.Dataset, torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    dataset = DictDataSet({k:v for k,v in data_dict.items() if (ignore_keys and k not in ignore_keys)})
    loader = torch.utils.data.DataLoader(dataset,**shuffled_loader_config)
    serial_loader = torch.utils.data.DataLoader(dataset,**serial_loader_config)

    return dataset,iter(recycle(loader)),serial_loader

class QueueAggregator(object):
    def __init__(self, max_size):
        self._queued_list = []
        self.max_size = max_size

    def append(self, elem):
        self._queued_list.append(elem)
        if len(self._queued_list) > self.max_size:
            self._queued_list.pop(0)

    def get(self):
        return self._queued_list
    
class EarlyStopping(object):
    def __init__(self, mode='min', min_delta=0, patience=10, percentage=False):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        if self.best is None:
            self.best = metrics
            return False

        if torch.isnan(metrics):
            return True

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            return True

        return False

    def _init_is_better(self, mode, min_delta, percentage):
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if not percentage:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - min_delta
            if mode == 'max':
                self.is_better = lambda a, best: a > best + min_delta
        else:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - (
                            best * min_delta / 100)
            if mode == 'max':
                self.is_better = lambda a, best: a > best + (
                            best * min_delta / 100)
                
def process_batch(batch: Dict[str,torch.tensor],
                  model: nn.Module,
                  quantiles_tensor: torch.tensor,
                  device:torch.device):
    if is_cuda:
        for k in list(batch.keys()):
            batch[k] = batch[k].to(device)

    batch_outputs = model(batch)
    labels = batch['target']

    predicted_quantiles = batch_outputs['predicted_quantiles']
    q_loss, q_risk, _ = tft_loss.get_quantiles_loss_and_q_risk(outputs=predicted_quantiles,
                                                              targets=labels,
                                                              desired_quantiles=quantiles_tensor)
    return q_loss, q_risk



# Load model 

In [7]:
model = TemporalFusionTransformer(config=OmegaConf.create(configuration))

In [8]:
is_cuda = torch.cuda.is_available()
device = torch.device("cuda" if is_cuda else "cpu")

model.to(device)

TemporalFusionTransformer(
  (static_transform): InputChannelEmbedding(
    (numeric_transform): NumericInputTransformation(
      (numeric_projection_layers): ModuleList(
        (0-1): 2 x Linear(in_features=1, out_features=160, bias=True)
      )
    )
    (categorical_transform): NullTransform()
  )
  (historical_ts_transform): InputChannelEmbedding(
    (numeric_transform): TimeDistributed(
      (module): NumericInputTransformation(
        (numeric_projection_layers): ModuleList(
          (0-6): 7 x Linear(in_features=1, out_features=160, bias=True)
        )
      )
    )
    (categorical_transform): NullTransform()
  )
  (future_ts_transform): InputChannelEmbedding(
    (numeric_transform): TimeDistributed(
      (module): NumericInputTransformation(
        (numeric_projection_layers): ModuleList(
          (0-5): 6 x Linear(in_features=1, out_features=160, bias=True)
        )
      )
    )
    (categorical_transform): NullTransform()
  )
  (static_selection): VariableSelec

In [9]:
state_dict = torch.load(checkpoint, map_location=device)
model_state = model.state_dict()

# Filter out incompatible keys
filtered_state_dict = {
    k: v for k, v in state_dict.items() if k in model_state and model_state[k].shape == v.shape
}

model.load_state_dict(filtered_state_dict, strict=False)

  state_dict = torch.load(checkpoint, map_location=device)


_IncompatibleKeys(missing_keys=['static_selection.flattened_grn.skip_layer.module.weight', 'static_selection.flattened_grn.skip_layer.module.bias', 'static_selection.flattened_grn.fc1.module.weight', 'static_selection.flattened_grn.fc2.module.weight', 'static_selection.flattened_grn.fc2.module.bias', 'static_selection.flattened_grn.gate.module.fc1.weight', 'static_selection.flattened_grn.gate.module.fc1.bias', 'static_selection.flattened_grn.gate.module.fc2.weight', 'static_selection.flattened_grn.gate.module.fc2.bias', 'static_selection.flattened_grn.layernorm.module.weight', 'static_selection.flattened_grn.layernorm.module.bias'], unexpected_keys=[])

In [10]:
model.to(device)
model.eval()

TemporalFusionTransformer(
  (static_transform): InputChannelEmbedding(
    (numeric_transform): NumericInputTransformation(
      (numeric_projection_layers): ModuleList(
        (0-1): 2 x Linear(in_features=1, out_features=160, bias=True)
      )
    )
    (categorical_transform): NullTransform()
  )
  (historical_ts_transform): InputChannelEmbedding(
    (numeric_transform): TimeDistributed(
      (module): NumericInputTransformation(
        (numeric_projection_layers): ModuleList(
          (0-6): 7 x Linear(in_features=1, out_features=160, bias=True)
        )
      )
    )
    (categorical_transform): NullTransform()
  )
  (future_ts_transform): InputChannelEmbedding(
    (numeric_transform): TimeDistributed(
      (module): NumericInputTransformation(
        (numeric_projection_layers): ModuleList(
          (0-5): 6 x Linear(in_features=1, out_features=160, bias=True)
        )
      )
    )
    (categorical_transform): NullTransform()
  )
  (static_selection): VariableSelec

In [11]:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
criterion = nn.MSELoss()

In [23]:
opt = optim.Adam(filter(lambda p: p.requires_grad, list(model.parameters())),
                    lr=configuration['optimization']['learning_rate'],
                    weight_decay=1e-4)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=opt, mode='min', factor = 0.5, patience=5, min_lr= 1e-6 )


In [17]:
shuffled_loader_config = {'batch_size': configuration['optimization']['batch_size']['training'],
                    'drop_last': True,
                    'shuffle':False}

serial_loader_config = {'batch_size': configuration['optimization']['batch_size']['inference'],
                'drop_last': False,
                'shuffle':False}

# the following fields do not contain actual data, but are only identifiers of each observation
#meta_keys = ['time', 'location','soil', "soil_x", "soil_y", "id"]
meta_keys = ['time', 'location', "soil_x", "soil_y", "id"]
train_set,train_loader,train_serial_loader = get_set_and_loaders(data['data_sets']['train'],
                                                                shuffled_loader_config,
                                                                serial_loader_config,
                                                                ignore_keys=meta_keys)
validation_set,validation_loader,validation_serial_loader = get_set_and_loaders(data['data_sets']['validation'],
                                                                shuffled_loader_config,
                                                                serial_loader_config,
                                                                ignore_keys=meta_keys)
test_set,test_loader,test_serial_loader = get_set_and_loaders(data['data_sets']['test'],
                                                                serial_loader_config,
                                                                serial_loader_config,
                                                                ignore_keys=meta_keys)

# If early stopping is not triggered, after how many epochs should we quit training
max_epochs = 100
# how many training batches will compose a single training epoch
epoch_iters = len(data['data_sets']['train']['time_index'])//8#was 200 #then 400
# upon completing a training epoch, we perform an evaluation of all the subsets
# eval_iters will define how many batches of each set will compose a single evaluation round
eval_iters = len(data['data_sets']['validation']['time_index'])//8 #500 #then 100
# during training, on what frequency should we display the monitored performance
log_interval = 50
# what is the running-window used by our QueueAggregator object for monitoring the training performance
ma_queue_size = 50
# how many evaluation rounds should we allow,
# without any improvement in the performance observed on the validation set
patience = 10
    
# initialize early stopping mechanism
es = EarlyStopping(patience=patience)
# initialize the loss aggregator for running window performance estimation
loss_aggregator = QueueAggregator(max_size=ma_queue_size)

# initialize counters
batch_idx = 0
epoch_idx = 0

quantiles_tensor = torch.tensor(configuration['model']['output_quantiles']).to(device)


In [24]:
while epoch_idx < max_epochs:
        print(f"Starting Epoch Index {epoch_idx}")

        # evaluation round
        model.eval()
        with torch.no_grad():
            # for each subset
            for subset_name, subset_loader in zip(['train','validation','test'],[train_loader,validation_loader,test_loader]):
                print(f"Evaluating {subset_name} set")

                q_loss_vals, q_risk_vals = [],[] # used for aggregating performance along the evaluation round
                for v in range(eval_iters):
                    #print(v)
                    # get batch
                    batch = next(subset_loader)
                    #batch = [item.to(device) for item in batch]
                    # process batch
                    batch_loss,batch_q_risk = process_batch(batch=batch,model=model,quantiles_tensor=quantiles_tensor,device=device)
                    # accumulate performance
                    q_loss_vals.append(batch_loss)
                    q_risk_vals.append(batch_q_risk)
                #print('done')
                # aggregate and average
                eval_loss = torch.stack(q_loss_vals).mean(axis=0)
                eval_q_risk = torch.stack(q_risk_vals,axis=0).mean(axis=0)

                # keep for feeding the early stopping mechanism
                if subset_name == 'validation':
                    validation_loss = eval_loss

                # log performance
                print(f"Epoch: {epoch_idx}, Batch Index: {batch_idx}" + \
                    f"- Eval {subset_name} - " + \
                    f"q_loss = {eval_loss:.5f} , " + \
                    " , ".join([f"q_risk_{q:.1} = {risk:.5f}" for q,risk in zip(quantiles_tensor,eval_q_risk)]))

                # log metrics to wandb
                wandb.log({"eval_q_loss": eval_loss})
        # switch to training mode
        model.train()

        # update early stopping mechanism and stop if triggered
        if es.step(validation_loss):
            print('Performing early stopping...!')
            break

        # initiating a training round
        for _ in range(epoch_iters):
            # get training batch
            batch = next(train_loader)

            opt.zero_grad()
            # process batch
            loss,_ = process_batch(batch=batch,
                                model=model,
                                quantiles_tensor=quantiles_tensor,
                                device=device)
            
            # Debug gradient norms
            total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            # compute gradients
            loss.backward()
            # update weights
            opt.step()
            # gradient clipping
            
            #torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)


            # accumulate performance
            loss_aggregator.append(loss.item())

            # log performance
            if batch_idx % log_interval == 0:
                print(f"Epoch: {epoch_idx}, Batch Index: {batch_idx} - Train Loss = {np.mean(loss_aggregator.get())}")

            # completed batch
            batch_idx += 1
            
        scheduler.step(validation_loss)
        
        #Log the learning rate
        current_lr = opt.param_groups[0]['lr']

        wandb.log({"learning_rate": current_lr})
            
        # log metrics to wandb
        wandb.log({"train_loss": np.mean(loss_aggregator.get())})

        # completed epoch
        epoch_idx += 1

Starting Epoch Index 0
Evaluating train set
Epoch: 0, Batch Index: 77- Eval train - q_loss = 0.03762 , q_risk_0.1 = 0.01942 , q_risk_0.5 = 0.04138 , q_risk_0.9 = 0.02063
Evaluating validation set
Epoch: 0, Batch Index: 77- Eval validation - q_loss = 0.03371 , q_risk_0.1 = 0.01723 , q_risk_0.5 = 0.03752 , q_risk_0.9 = 0.02010
Evaluating test set
Epoch: 0, Batch Index: 77- Eval test - q_loss = 0.03190 , q_risk_0.1 = 0.01803 , q_risk_0.5 = 0.03425 , q_risk_0.9 = 0.01925
Epoch: 0, Batch Index: 100 - Train Loss = 0.042504263184964655
Epoch: 0, Batch Index: 150 - Train Loss = 0.030155572667717935
Starting Epoch Index 1
Evaluating train set
Epoch: 1, Batch Index: 154- Eval train - q_loss = 0.02079 , q_risk_0.1 = 0.00957 , q_risk_0.5 = 0.02376 , q_risk_0.9 = 0.01354
Evaluating validation set
Epoch: 1, Batch Index: 154- Eval validation - q_loss = 0.02428 , q_risk_0.1 = 0.01234 , q_risk_0.5 = 0.02923 , q_risk_0.9 = 0.01397
Evaluating test set
Epoch: 1, Batch Index: 154- Eval test - q_loss = 0.02

In [25]:
# Save model weights to a file
print('Training over.')
torch.save(model.state_dict(), output_path)
wandb.finish()


Training over.


0,1
eval_q_loss,██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
learning_rate,████████████████▃▃▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁
train_loss,█▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▁▁▁▂▁▁▂▁▁▁

0,1
eval_q_loss,0.01159
learning_rate,3e-05
train_loss,0.01164


# Predict 

In [27]:
prediction_path="/home/jovyan/phenology-ml-clm/data/predictions/USMMS_06082025.pkl"
model.eval() # switch to evaluation mode


TemporalFusionTransformer(
  (static_transform): InputChannelEmbedding(
    (numeric_transform): NumericInputTransformation(
      (numeric_projection_layers): ModuleList(
        (0-1): 2 x Linear(in_features=1, out_features=160, bias=True)
      )
    )
    (categorical_transform): NullTransform()
  )
  (historical_ts_transform): InputChannelEmbedding(
    (numeric_transform): TimeDistributed(
      (module): NumericInputTransformation(
        (numeric_projection_layers): ModuleList(
          (0-6): 7 x Linear(in_features=1, out_features=160, bias=True)
        )
      )
    )
    (categorical_transform): NullTransform()
  )
  (future_ts_transform): InputChannelEmbedding(
    (numeric_transform): TimeDistributed(
      (module): NumericInputTransformation(
        (numeric_projection_layers): ModuleList(
          (0-5): 6 x Linear(in_features=1, out_features=160, bias=True)
        )
      )
    )
    (categorical_transform): NullTransform()
  )
  (static_selection): VariableSelec

In [30]:
output_aggregator = dict() # will be used for aggregating the outputs across batches
with torch.no_grad():
    # go over the batches of the serial data loader
    for batch in tqdm(test_serial_loader):# change this from validation serial loader
        # process each batch
        if is_cuda:
            for k in list(batch.keys()):
                batch[k] = batch[k].to(device)
        batch_outputs = model(batch)

        # accumulate outputs, as well as labels
        for output_key,output_tensor in batch_outputs.items():
            output_aggregator.setdefault(output_key,[]).append(output_tensor.cpu().numpy())
        
validation_outputs = dict()
for k in list(output_aggregator.keys()):
    validation_outputs[k] = np.concatenate(output_aggregator[k],axis=0)

# Save the dictionary using Pickle
with open(prediction_path, "wb") as pickle_file:
    print("saving in", prediction_path)
    pickle.dump(validation_outputs, pickle_file)

100%|██████████| 13/13 [00:00<00:00, 59.54it/s]

saving in /home/jovyan/phenology-ml-clm/data/predictions/USMMS_06082025.pkl



