In [1]:
import torch
import os
from tft_torch.tft import TemporalFusionTransformer
import pickle
from omegaconf import OmegaConf, DictConfig
import torch
import tft_torch

# Show which tft_torch installation is actually imported.
print(tft_torch.__file__)

# Project root; set the PHENOLOGY_ML_ROOT environment variable to run from another location.
PROJECT_ROOT = os.environ.get("PHENOLOGY_ML_ROOT", "/glade/u/home/ayal/phenology-ml-clm")
checkpoint_path = os.path.join(PROJECT_ROOT, "models", "tft_scripted.pt")
data_path = os.path.join(PROJECT_ROOT, "data", "USMMS_to_tft_06132025.pkl")

configuration = {
    'optimization': {
        'batch_size': {'training': 8, 'inference': 8},  # both previously 64
        'learning_rate': 1e-4,  # previously 0.001
        'max_grad_norm': 1.0,
    },
    'model': {
        'dropout': 0.2,  # previously 0.05
        'state_size': 160,
        'output_quantiles': [0.1, 0.5, 0.9],
        'lstm_layers': 4,  # previously 2
        'attention_heads': 4,
    },
    # these arguments are related to possible extensions of the model class
    'task_type': 'regression',
    'target_window_start': None,
    'checkpoint': checkpoint_path,
}

# Load the preprocessed dataset.
# SECURITY: pickle.load can execute arbitrary code embedded in the file -- only load
# pickles produced by a trusted source.
with open(data_path, 'rb') as fp:
    data = pickle.load(fp)

feature_map = data['feature_map']
cardinalities_map = data['categorical_cardinalities']


def build_data_props(feature_map, cardinalities_map):
    """Build the ``data_props`` section of the TFT configuration from dataset metadata.

    Args:
        feature_map: mapping from feature-group name ('historical_ts_numeric',
            'historical_ts_categorical', 'static_feats_numeric',
            'static_feats_categorical', 'future_ts_numeric', 'future_ts_categorical')
            to the feature names in that group.
        cardinalities_map: mapping from categorical feature name to its stored cardinality.

    Returns:
        dict with the number of features in each group and, for each categorical
        group, the per-feature cardinalities passed to the model (stored value + 1).
    """
    def _cardinalities(group):
        # +1 presumably reserves one extra category index (e.g. padding/unknown) -- TODO confirm.
        return [cardinalities_map[feat] + 1 for feat in feature_map[group]]

    return {
        'num_historical_numeric': len(feature_map['historical_ts_numeric']),
        'num_historical_categorical': len(feature_map['historical_ts_categorical']),
        'num_static_numeric': len(feature_map['static_feats_numeric']),
        'num_static_categorical': len(feature_map['static_feats_categorical']),
        'num_future_numeric': len(feature_map['future_ts_numeric']),
        'num_future_categorical': len(feature_map['future_ts_categorical']),
        'historical_categorical_cardinalities': _cardinalities('historical_ts_categorical'),
        'static_categorical_cardinalities': _cardinalities('static_feats_categorical'),
        'future_categorical_cardinalities': _cardinalities('future_ts_categorical'),
    }


structure = build_data_props(feature_map, cardinalities_map)
configuration['data_props'] = structure



/glade/u/home/ayal/tft-torch/tft_torch/__init__.py


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [2]:
# Display the assembled configuration (optimization, model and task settings plus the
# data_props derived from the dataset's feature map).
configuration

{'optimization': {'batch_size': {'training': 8, 'inference': 8},
  'learning_rate': 0.0001,
  'max_grad_norm': 1.0},
 'model': {'dropout': 0.2,
  'state_size': 160,
  'output_quantiles': [0.1, 0.5, 0.9],
  'lstm_layers': 4,
  'attention_heads': 4},
 'task_type': 'regression',
 'target_window_start': None,
 'checkpoint': '/glade/u/home/ayal/phenology-ml-clm/models/tft_scripted.pt',
 'data_props': {'num_historical_numeric': 8,
  'num_historical_categorical': 0,
  'num_static_numeric': 2,
  'num_static_categorical': 0,
  'num_future_numeric': 1,
  'num_future_categorical': 0,
  'historical_categorical_cardinalities': [],
  'static_categorical_cardinalities': [],
  'future_categorical_cardinalities': []}}

In [3]:
# Confirm the plain dict converts cleanly to an OmegaConf config (TemporalFusionTransformer is built from one).
OmegaConf.create(obj=configuration)

{'optimization': {'batch_size': {'training': 8, 'inference': 8}, 'learning_rate': 0.0001, 'max_grad_norm': 1.0}, 'model': {'dropout': 0.2, 'state_size': 160, 'output_quantiles': [0.1, 0.5, 0.9], 'lstm_layers': 4, 'attention_heads': 4}, 'task_type': 'regression', 'target_window_start': None, 'checkpoint': '/glade/u/home/ayal/phenology-ml-clm/models/tft_scripted.pt', 'data_props': {'num_historical_numeric': 8, 'num_historical_categorical': 0, 'num_static_numeric': 2, 'num_static_categorical': 0, 'num_future_numeric': 1, 'num_future_categorical': 0, 'historical_categorical_cardinalities': [], 'static_categorical_cardinalities': [], 'future_categorical_cardinalities': []}}

In [4]:
# Test configuration used to instantiate an eager TemporalFusionTransformer.
# data_props reuses `structure` (derived from the dataset's feature map in the first cell)
# rather than a hand-typed copy, so the two cannot drift apart. For this dataset that is:
# 8 historical numeric (tmin, tmax, prcp, srad, swc, photoperiod, doy, lai),
# 2 static numeric (lat, lon), 1 future numeric (doy), and no categorical features.
# NOTE(review): state_size is 64 here, while `configuration` (which names the scripted
# checkpoint) uses 160 -- confirm which value the checkpoint was trained with.
configuration_test = {
    "task_type": "regression",
    "target_window_start": None,
    "data_props": dict(structure),
    "model": {
        "attention_heads": 4,
        "dropout": 0.2,
        "lstm_layers": 4,
        "state_size": 64,
        "output_quantiles": [0.1, 0.5, 0.9],
    },
}


In [10]:
# Instantiate an eager (non-scripted) TFT from the test config, mainly to check the config
# is accepted. It gets its own name so the scripted checkpoint loaded as `model` afterwards
# does not silently replace it.
eager_model = TemporalFusionTransformer(config=OmegaConf.create(configuration_test))

In [14]:
# Use the GPU when one is available and load the TorchScript checkpoint directly onto it.
# Inputs passed to `model` must be moved to the same `device` (e.g. `tensor.to(device)`).
is_cuda = torch.cuda.is_available()
device = torch.device("cuda" if is_cuda else "cpu")
model = torch.jit.load(checkpoint_path, map_location=device)


In [15]:
# Put the loaded TorchScript module in evaluation mode (disables training-only behavior
# such as dropout). eval() returns the module itself, so the cell displays its repr.
model.eval()

RecursiveScriptModule(original_name=TemporalFusionTransformer)

In [None]:
import torch

# Smoke test: run the scripted checkpoint on a random batch whose feature dimensions come
# from configuration['data_props'], so the dummy inputs match what the model was built for
# (for this dataset: 8 historical numeric, 2 static numeric, 1 future numeric, no categorical).
torch.manual_seed(0)  # reproducible dummy inputs

B, T_hist, T_fut = 8, 60, 10  # batch size, history length, forecast horizon
# NOTE(review): T_hist / T_fut are arbitrary here -- confirm they match the windows the
# checkpoint was trained/scripted with if the model requires fixed lengths.

props = configuration['data_props']
for group in ('historical', 'static', 'future'):
    assert props[f'num_{group}_categorical'] == len(props[f'{group}_categorical_cardinalities'])


def _random_categorical(leading_shape, cardinalities):
    """Random int64 category ids of shape (*leading_shape, len(cardinalities)).

    Column i is drawn uniformly from [0, cardinalities[i]). With no categorical
    features the result has a trailing dimension of size 0.
    """
    if not cardinalities:
        return torch.empty(*leading_shape, 0, dtype=torch.long)
    columns = [torch.randint(0, card, tuple(leading_shape), dtype=torch.long)
               for card in cardinalities]
    return torch.stack(columns, dim=-1)


# static: (B, num_static_numeric) -- lat, lon
static_num = torch.randn(B, props['num_static_numeric'])
static_cat = _random_categorical((B,), props['static_categorical_cardinalities'])
# historical: (B, T_hist, num_historical_numeric)
hist_num = torch.randn(B, T_hist, props['num_historical_numeric'])
hist_cat = _random_categorical((B, T_hist), props['historical_categorical_cardinalities'])
# future: (B, T_fut, num_future_numeric) -- doy
fut_num = torch.randn(B, T_fut, props['num_future_numeric'])
fut_cat = _random_categorical((B, T_fut), props['future_categorical_cardinalities'])

# Load from checkpoint_path instead of a bare relative filename so the cell does not depend
# on the notebook's working directory; the inputs are on CPU, so map the model there too.
scripted = torch.jit.load(checkpoint_path, map_location="cpu")
scripted.eval()
with torch.no_grad():
    # NOTE(review): assumes the scripted forward takes these six tensors positionally and
    # returns a tensor of shape (B, T_fut, len(output_quantiles)) -- confirm against the
    # wrapper that was scripted.
    out = scripted(static_num,
                   static_cat,
                   hist_num,
                   hist_cat,
                   fut_num,
                   fut_cat)

out.shape  # expected torch.Size([8, 10, 3]) with the sizes above


In [None]:
# Inference on a tuple of input tensors, reusing the scripted model and the dummy inputs
# built in the previous cell (the names `scripted_model` / `my_input_tensor` were never
# defined, so this cell previously raised NameError on a fresh kernel).
model_inputs = (static_num, static_cat, hist_num, hist_cat, fut_num, fut_cat)
with torch.no_grad():
    out = scripted(*model_inputs)