In [6]:
import torch
import os
import pickle
import torch

# Feature map describing the model inputs. Presumably this list order is the
# last-axis column order of each numeric tensor below — confirm against the
# training pipeline.
# {'static_feats_numeric': ['latitude', 'longitude'],
#  'static_feats_categorical': [],
#  'historical_ts_numeric': ['tmin', 'tmax', 'precipitation', 'radiation',
#                            'photoperiod', 'swvl1', 'doy', 'sif_clear_inst'],
#  'historical_ts_categorical': [],
#  'future_ts_numeric': ['doy'],
#  'future_ts_categorical': []}

#### DATA: random placeholder inputs (not real observations) for a smoke test.
SEED = 0
# Fixed seed: re-running this cell (or Restart & Run All) reproduces the same
# inputs, so recorded forecasts stay comparable between cells.
torch.manual_seed(SEED)

B, T_hist, T_fut = 1, 60, 10  # batch size, historical steps, future steps

# static: (B, 2) -> latitude, longitude
static_num = torch.randn(B, 2)
static_cat = torch.empty(B, 0, dtype=torch.long)  # no static categorical features
# historical: (B, T_hist, 8) -> tmin, tmax, precipitation, radiation,
# photoperiod, swvl1, doy, sif_clear_inst (feature-map order)
hist_num = torch.randn(B, T_hist, 8)
# NOTE(review): 2-D (B, 0) although hist_num is 3-D; the scripted model accepted
# this in the recorded run, but confirm the intended shape.
hist_cat = torch.empty(B, 0, dtype=torch.long)
# future: (B, T_fut, 1) -> doy
fut_num = torch.randn(B, T_fut, 1)
fut_cat = torch.empty(B, 0, dtype=torch.long)

#### Model Path
# Directory holding the TorchScript checkpoints. Set $PHENOLOGY_MODELS_DIR to use
# another location; the default is the original hard-coded directory.
MODELS_DIR = os.environ.get(
    "PHENOLOGY_MODELS_DIR", "/glade/u/home/ayal/phenology-ml-clm/models"
)
checkpoint_path = os.path.join(MODELS_DIR, "tft_scripted.pt")

#### Model Configuration
# Architecture / data description of the scripted TFT.
# NOTE(review): no visible cell reads this dict (torch.jit.load restores the model
# without it); it appears to be kept for reference only — confirm before editing.
configuration_test = {
    "task_type": "regression",
    "target_window_start": None,
    "data_props": {
        # These counts mirror the last-axis sizes of the placeholder input tensors.
        # Historical numeric, in feature-map order: tmin, tmax, precipitation,
        # radiation, photoperiod, swvl1, doy, sif_clear_inst.
        "num_historical_numeric": 8,
        "num_historical_categorical": 0,
        "num_static_numeric": 2,  # latitude, longitude
        "num_static_categorical": 0,
        "num_future_numeric": 1,  # doy
        "num_future_categorical": 0,
        "historical_categorical_cardinalities": [],
        "static_categorical_cardinalities": [],
        "future_categorical_cardinalities": [],
    },
    "model": {
        "attention_heads": 4,
        "dropout": 0.2,
        "lstm_layers": 4,
        "state_size": 160,
        # The forecast's last axis has one entry per quantile (3 here).
        "output_quantiles": [0.1, 0.5, 0.9],
    },
}

### Load the model and run a smoke-test forecast
model = torch.jit.load(checkpoint_path, map_location="cpu")
model.eval()  # inference mode so dropout layers are inactive

with torch.no_grad():
    # returns (B, T_fut, n_quantiles)
    out = model(static_num,
                static_cat,
                hist_num,
                hist_cat,
                fut_num,
                fut_cat)

# Select the median column by looking up 0.5 in the configured quantiles instead of
# hard-coding index 1, and fail loudly if the output layout is not as expected.
quantiles = configuration_test["model"]["output_quantiles"]
assert out.shape == (B, T_fut, len(quantiles)), f"unexpected output shape {tuple(out.shape)}"
median_idx = quantiles.index(0.5)
out[:, :, median_idx]


tensor([[-0.7310, -0.7871, -0.7402, -0.6085, -0.7669, -0.8404, -0.6197, -0.8075,
         -0.7866, -0.8444]])

In [None]:
import torch
import os
import pickle
import torch

# Feature map describing the model inputs. Presumably this list order is the
# last-axis column order of each numeric tensor below — confirm against the
# training pipeline.
# {'static_feats_numeric': ['latitude', 'longitude'],
#  'static_feats_categorical': [],
#  'historical_ts_numeric': ['tmin', 'tmax', 'precipitation', 'radiation',
#                            'photoperiod', 'swvl1', 'doy', 'sif_clear_inst'],
#  'historical_ts_categorical': [],
#  'future_ts_numeric': ['doy'],
#  'future_ts_categorical': []}

#### DATA: random placeholder inputs (not real observations) for a smoke test.
SEED = 0
# Fixed seed: re-running this cell (or Restart & Run All) reproduces the same
# inputs, so recorded forecasts stay comparable between cells.
torch.manual_seed(SEED)

B, T_hist, T_fut = 1, 60, 10  # batch size, historical steps, future steps

# static: (B, 2) -> latitude, longitude
static_num = torch.randn(B, 2)
static_cat = torch.empty(B, 0, dtype=torch.long)  # no static categorical features
# historical: (B, T_hist, 8) -> tmin, tmax, precipitation, radiation,
# photoperiod, swvl1, doy, sif_clear_inst (feature-map order)
hist_num = torch.randn(B, T_hist, 8)
# NOTE(review): 2-D (B, 0) although hist_num is 3-D; the scripted model accepted
# this in the recorded run, but confirm the intended shape.
hist_cat = torch.empty(B, 0, dtype=torch.long)
# future: (B, T_fut, 1) -> doy
fut_num = torch.randn(B, T_fut, 1)
fut_cat = torch.empty(B, 0, dtype=torch.long)

#### Model Path
# Directory holding the TorchScript checkpoints. Set $PHENOLOGY_MODELS_DIR to use
# another location; the default is the original hard-coded directory.
MODELS_DIR = os.environ.get(
    "PHENOLOGY_MODELS_DIR", "/glade/u/home/ayal/phenology-ml-clm/models"
)
checkpoint_path = os.path.join(MODELS_DIR, "tft_scripted.pt")

#### Model Configuration
# Architecture / data description of the scripted TFT.
# NOTE(review): no visible cell reads this dict (torch.jit.load restores the model
# without it); it appears to be kept for reference only — confirm before editing.
configuration_test = {
    "task_type": "regression",
    "target_window_start": None,
    "data_props": {
        # These counts mirror the last-axis sizes of the placeholder input tensors.
        # Historical numeric, in feature-map order: tmin, tmax, precipitation,
        # radiation, photoperiod, swvl1, doy, sif_clear_inst.
        "num_historical_numeric": 8,
        "num_historical_categorical": 0,
        "num_static_numeric": 2,  # latitude, longitude
        "num_static_categorical": 0,
        "num_future_numeric": 1,  # doy
        "num_future_categorical": 0,
        "historical_categorical_cardinalities": [],
        "static_categorical_cardinalities": [],
        "future_categorical_cardinalities": [],
    },
    "model": {
        "attention_heads": 4,
        "dropout": 0.2,
        "lstm_layers": 4,
        "state_size": 160,
        # The forecast's last axis has one entry per quantile (3 here).
        "output_quantiles": [0.1, 0.5, 0.9],
    },
}

### Load the model
model = torch.jit.load(checkpoint_path, map_location="cpu")
# A loaded TorchScript module keeps the train/eval flag it was saved with, so put it
# in eval mode explicitly (as cell 1 does) before any inference. The returned module
# repr is this cell's displayed output.
model.eval()



RecursiveScriptModule(original_name=TemporalFusionTransformer)

In [4]:
from typing import Tuple
class OneArgWrapper(torch.nn.Module):
    """Expose a six-argument TFT forward through a single tuple argument.

    The wrapped module is invoked as
    ``model(static_num, static_cat, hist_num, hist_cat, fut_num, fut_cat)``.
    This wrapper accepts those same six tensors packed into one tuple, in that
    order, and returns whatever the wrapped module returns.
    """

    def __init__(self, model: torch.jit.ScriptModule):
        super().__init__()
        # Stored as a submodule so torch.jit.script / torch.jit.freeze can see it.
        self.model = model

    def forward(
        self,
        all_inputs: Tuple[
            torch.Tensor, torch.Tensor,
            torch.Tensor, torch.Tensor,
            torch.Tensor, torch.Tensor
        ]
    ) -> torch.Tensor:
        # Constant-index tuple access keeps this TorchScript-compatible.
        static_numeric = all_inputs[0]
        static_categorical = all_inputs[1]
        historical_numeric = all_inputs[2]
        historical_categorical = all_inputs[3]
        future_numeric = all_inputs[4]
        future_categorical = all_inputs[5]
        return self.model(
            static_numeric,
            static_categorical,
            historical_numeric,
            historical_categorical,
            future_numeric,
            future_categorical,
        )

In [6]:
# Wrap the scripted TFT so it takes one tuple argument, then script, freeze and save it.
wrapper = OneArgWrapper(torch.jit.load(checkpoint_path, map_location="cpu"))
wrapper.eval()  # torch.jit.freeze only accepts modules in eval mode
scripted = torch.jit.script(wrapper)
model_frozen = torch.jit.freeze(scripted)
# Save next to the source checkpoint rather than repeating the absolute directory.
out_file = os.path.join(os.path.dirname(checkpoint_path), "tft_onearg.pt")
torch.jit.save(model_frozen, out_file)


In [None]:
# Load the frozen one-argument module back as a ScriptModule.
loaded = torch.jit.load(out_file, map_location="cpu")
loaded.eval()

with torch.no_grad():
    # OneArgWrapper.forward takes a SINGLE tuple of the six tensors; passing them as
    # six positional arguments raises a RuntimeError (too many arguments).
    # returns (B, T_fut, 3)
    out = loaded((static_num,
                  static_cat,
                  hist_num,
                  hist_cat,
                  fut_num,
                  fut_cat))

RecursiveScriptModule(original_name=OneArgWrapper)

In [8]:
# Pack the six inputs in the order OneArgWrapper.forward unpacks them.
tup = (
    static_num, static_cat,
    hist_num, hist_cat,
    fut_num, fut_cat,
)


In [10]:
# Run the frozen one-argument wrapper without tracking gradients.
with torch.set_grad_enabled(False):
    out = loaded(tup)  # quantile forecast, (B, T_fut, 3)


In [11]:
# Display the forecast: shape (B, T_fut, 3); the last axis presumably follows
# output_quantiles [0.1, 0.5, 0.9] — confirm against the model's training code.
out

tensor([[[-0.7272, -0.7002,  0.1767],
         [-0.7929, -0.7575,  0.0170],
         [-0.8474, -0.7602, -0.0192],
         [-0.7739, -0.7363, -0.1360],
         [-0.8664, -0.7430, -0.0207],
         [-0.5299, -0.5162, -0.1651],
         [-0.4806, -0.3817,  0.0702],
         [-0.5552, -0.4563,  0.0342],
         [-0.5705, -0.4498,  0.0450],
         [-0.8032, -0.7379,  0.0134]]])