In [1]:
import torch
from transformers import TimeSeriesTransformerForPrediction
from transformers import TimeSeriesTransformerConfig

In [2]:
# Initializing a Time Series Transformer configuration with 10 time steps for prediction
configuration = TimeSeriesTransformerConfig(prediction_length=10,
                                           context_length=30,
                                           distribution_output='student_t',
                                           input_size=4,
                                           loss = 'nll',
                                           
                                           lags_sequence=[1,2,3,4],
                                           num_time_features=3,
                                           cardinality=None
                                            )

# Randomly initializing a model (with random weights) from the configuration
model = TimeSeriesTransformerForPrediction(configuration)

# Accessing the model configuration
configuration = model.config
configuration

TimeSeriesTransformerConfig {
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "attention_dropout": 0.1,
  "cardinality": [
    0
  ],
  "context_length": 30,
  "d_model": 64,
  "decoder_attention_heads": 2,
  "decoder_ffn_dim": 32,
  "decoder_layerdrop": 0.1,
  "decoder_layers": 2,
  "distribution_output": "student_t",
  "dropout": 0.1,
  "embedding_dimension": [
    0
  ],
  "encoder_attention_heads": 2,
  "encoder_ffn_dim": 32,
  "encoder_layerdrop": 0.1,
  "encoder_layers": 2,
  "feature_size": 27,
  "init_std": 0.02,
  "input_size": 4,
  "is_encoder_decoder": true,
  "lags_sequence": [
    1,
    2,
    3,
    4
  ],
  "loss": "nll",
  "model_type": "time_series_transformer",
  "num_dynamic_real_features": 0,
  "num_parallel_samples": 100,
  "num_static_categorical_features": 0,
  "num_static_real_features": 0,
  "num_time_features": 3,
  "prediction_length": 10,
  "scaling": "mean",
  "transformers_version": "4.45.2",
  "use_cache": true
}

In [3]:
best_model_path = 'best_model.pth'
model.load_state_dict(torch.load(best_model_path))

<All keys matched successfully>

In [4]:
past_time_features_test = torch.load('test_past_time_features_tensor.pt')
future_time_features_test=torch.load('test_future_time_features_tensor.pt')
past_values_tensors_test=torch.load('test_past_value_tensor.pt')
future_values_tensors_test=torch.load('test_future_value_tensor.pt')

In [5]:
past_time_features_test.shape,future_time_features_test.shape,past_values_tensors_test.shape

(torch.Size([38949, 34, 3]),
 torch.Size([38949, 10, 3]),
 torch.Size([38949, 34, 4]))

In [6]:
# Generate a past_observed_mask tensor with shape [42844, 30,4], all elements equal to 1
past_observed_mask_tensor = torch.ones((38949, 34,4), dtype=torch.float)

# Check the shape
past_observed_mask_tensor.shape

torch.Size([38949, 34, 4])

In [7]:
from torch.utils.data import DataLoader, TensorDataset
batch_size = 1000

# Create a DataLoader to handle mini-batches
test_dataset = TensorDataset(past_values_tensors_test, past_time_features_test, future_values_tensors_test, future_time_features_test,past_observed_mask_tensor)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [8]:
# Total number of batches
total_batches = len(test_loader)

# Generate predictions in mini-batches
all_mean_predictions = []

print("Starting prediction...")

for batch_idx, batch in enumerate(test_loader):
    # Unpack the batch
    batch_past_values, batch_past_time_features, batch_future_values, batch_future_time_features, batch_past_observed_mask = batch
    
    # Model prediction
    outputs = model.generate(
        past_values=batch_past_values,
        past_time_features=batch_past_time_features,
        future_time_features=batch_future_time_features,
        past_observed_mask=batch_past_observed_mask
    )
    
    # Compute the mean prediction
    mean_prediction = outputs.sequences.mean(dim=1)
    
    # Collect predictions
    all_mean_predictions.append(mean_prediction)
    
    # Print progress for every 10th batch
    print(f"Processed batch {batch_idx+1}/{total_batches} ({(batch_idx+1)/total_batches*100:.2f}%)")

print("Prediction complete!")


Starting prediction...
Processed batch 1/39 (2.56%)
Processed batch 2/39 (5.13%)
Processed batch 3/39 (7.69%)
Processed batch 4/39 (10.26%)
Processed batch 5/39 (12.82%)
Processed batch 6/39 (15.38%)
Processed batch 7/39 (17.95%)
Processed batch 8/39 (20.51%)
Processed batch 9/39 (23.08%)
Processed batch 10/39 (25.64%)
Processed batch 11/39 (28.21%)
Processed batch 12/39 (30.77%)
Processed batch 13/39 (33.33%)
Processed batch 14/39 (35.90%)
Processed batch 15/39 (38.46%)
Processed batch 16/39 (41.03%)
Processed batch 17/39 (43.59%)
Processed batch 18/39 (46.15%)
Processed batch 19/39 (48.72%)
Processed batch 20/39 (51.28%)
Processed batch 21/39 (53.85%)
Processed batch 22/39 (56.41%)
Processed batch 23/39 (58.97%)
Processed batch 24/39 (61.54%)
Processed batch 25/39 (64.10%)
Processed batch 26/39 (66.67%)
Processed batch 27/39 (69.23%)
Processed batch 28/39 (71.79%)
Processed batch 29/39 (74.36%)
Processed batch 30/39 (76.92%)
Processed batch 31/39 (79.49%)
Processed batch 32/39 (82.05

In [9]:
all_mean_predictions = torch.cat(all_mean_predictions, dim=0)

In [13]:
all_mean_predictions

tensor([[[ 1.7086e+00,  1.6352e+01,  1.6659e+01, -2.2436e-01],
         [ 1.1016e+01,  6.7102e+01,  4.6715e+01,  1.8072e+00],
         [ 3.7967e+01,  3.0463e+01, -1.0439e+01,  1.2513e+00],
         ...,
         [ 4.6800e+01, -3.6752e+01, -3.4721e-01,  2.0966e+00],
         [ 5.6351e+01, -1.8228e+01, -4.8661e+01, -2.0502e-01],
         [ 5.1710e+01, -1.9051e+01, -2.1022e+01,  6.4342e-01]],

        [[-1.4895e+00,  1.5356e+01,  2.6123e+01,  2.3519e+00],
         [ 1.6966e+01,  7.4463e+00, -2.7842e+00,  1.1521e-01],
         [ 1.1746e+01,  1.1001e+01, -4.5426e+00, -2.3459e+00],
         ...,
         [ 1.5824e+00, -4.6295e+00,  1.3863e+01, -1.9528e+00],
         [ 1.8406e+01,  1.2853e+01,  3.2364e+01, -1.7912e+00],
         [ 6.6342e+00,  1.0064e+01, -2.5500e+00, -4.3675e+00]],

        [[ 1.0777e+01,  2.9949e+00, -1.2926e+01, -3.3928e-01],
         [-1.3940e+01,  3.5890e+01,  5.0964e+01,  3.7098e+00],
         [ 4.9544e+00,  9.4251e+00,  1.1573e+01,  1.7193e+00],
         ...,
         

In [12]:
torch.save(all_mean_predictions, 'prediction_value.pt')