In [1]:
import os
os.chdir('/home/kreffert/Probabilistic_LTSF/BasicTS/')
from basicts.metrics import masked_mae, masked_mse, nll_loss, crps, Evaluator, quantile_loss, empirical_crps
from easytorch.device import set_device_type
from easytorch.utils import get_logger, set_visible_devices
# set the device type (CPU, GPU, or MLU)
device_type ='gpu'
gpus = '0'
set_device_type(device_type)
set_visible_devices(gpus)
from easydict import EasyDict
from tqdm import tqdm

def load_cfg(cfg, random_state=None):
    from easytorch.config import init_cfg
    # cfg path which start with dot will crash the easytorch, just remove dot
    while isinstance(cfg, str) and cfg.startswith(('./','.\\')):
        cfg = cfg[2:]
    # while ckpt_path.startswith(('./','.\\')):
    #     ckpt_path = ckpt_path[2:]
    
    # initialize the configuration
    cfg = init_cfg(cfg, save=False)
    # cfg['METRICS'] = EasyDict()
    # all_metrics = [#"MSE", "abs_error", "abs_target_sum", "abs_target_mean",
    #                 # "MAPE", "sMAPE", "MASE", "RMSE", "NRMSE", "ND", "weighted_ND",
    #                 "mean_absolute_QuantileLoss", "CRPS", "MAE_Coverage", "NLL", 
    #                 #"VS", "ES"
    #                 ]
    # cfg['METRICS']['FUNCS'] = EasyDict({
    #     'NLL': nll_loss,
    #     'CRPS': crps,
    #     # 'Evaluator': Evaluator(distribution_type=MODEL_PARAM['distribution_type'], 
    #     #                        quantiles=MODEL_PARAM['quantiles']),
    #     'Val_Evaluator': Evaluator(distribution_type=cfg['MODEL']['PARAM']['distribution_type'], metrics = all_metrics,
    #                             quantiles=cfg['MODEL']['PARAM']['prob_args']['quantiles']),  # only use the evaluator during validation/testing iters
    # })

    
    if random_state is not None:
        print(f'Using random state {random_state}')
        # import os
        # os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        cfg['ENV'] = EasyDict() # Environment settings. Default: None
        # GPU and random seed settings
        cfg['ENV']['TF32'] = True # Whether to use TensorFloat-32 in GPU. Default: False. See https://pytorch.org/docs/stable/notes/cuda.html#tf32-on-ampere.
        cfg['ENV']['SEED'] = random_state # Random seed. Default: None
        cfg['ENV']['DETERMINISTIC'] = True # Whether to set the random seed to get deterministic results. Default: False
        cfg['ENV']['CUDNN'] = EasyDict()
        cfg['ENV']['CUDNN']['ENABLED'] = True # Whether to enable cuDNN. Default: True
        cfg['ENV']['CUDNN']['BENCHMARK'] = True# Whether to enable cuDNN benchmark. Default: True
        cfg['ENV']['CUDNN']['DETERMINISTIC'] = True # Whether to set cuDNN to deterministic mode. Default: False
    return cfg

def load_runner(configs, random_states=[]):
    for rs in random_states:
        for key in configs[rs].keys():
            configs[rs][key]['cfg'] = load_cfg(configs[rs][key]['cfg'], random_state=rs)
            cfg = configs[rs][key]['cfg']
            ckpt_path = '/home/kreffert/Probabilistic_LTSF/BasicTS/' + configs[rs][key]['ckpt']
            strict = True
            runner = cfg['RUNNER'](cfg)
            # setup the graph if needed
            if runner.need_setup_graph:
                runner.setup_graph(cfg=cfg, train=False)
                
            print(f'Loading model checkpoint from {ckpt_path}')
            runner.load_model(ckpt_path=ckpt_path, strict=strict)
            
            # runner.test_pipeline(cfg=cfg, save_metrics=False, save_results=False)
            configs[rs][key]['runner'] = runner
    return configs

import torch

@torch.no_grad()
def get_predictions(configs):
    for rs in configs.keys():
        for key in configs[rs].keys():
            runner = configs[rs][key]['runner']
            cfg = configs[rs][key]['cfg']
            # init test
            runner.test_interval = cfg['TEST'].get('INTERVAL', 1)
            runner.test_data_loader = runner.build_test_data_loader(cfg)
        
            runner.model.eval()
            prediction, target, inputs = [], [], []
        
            for data in tqdm(runner.test_data_loader):
                forward_return = runner.forward(data, epoch=None, iter_num=None, train=False)
                if not runner.if_evaluate_on_gpu:
                    forward_return['prediction'] = forward_return['prediction'].detach().cpu()
                    forward_return['target'] = forward_return['target'].detach().cpu()
                    forward_return['inputs'] = forward_return['inputs'].detach().cpu()
        
                prediction.append(forward_return['prediction'])
                target.append(forward_return['target'])
                inputs.append(forward_return['inputs'])
        
            prediction = torch.cat(prediction, dim=0)
            target = torch.cat(target, dim=0)
            inputs = torch.cat(inputs, dim=0)
        
            returns_all = {'prediction': prediction, 'target': target, 'inputs': inputs}
            configs[rs][key]['returns_all'] = returns_all
    return configs

2025-06-08 13:08:17,233 - easytorch-env - INFO - Use devices 0.


In [2]:
# 1. load the model and set the device
_configs = {'ETTh1_PTST_u': {'cfg':'final_weights/PatchTST/univariate/ETTh1_prob.py',
                           'ckpt': 'final_weights/PatchTST/univariate/ETTh1_100_96_720/a8de06edad7530010e0b704422b431a2/PatchTST_best_val_NLL.pt'
                          },
           # 'ETTh1_PTST_q': {'cfg': 'final_weights/PatchTST/quantile/ETTh1_prob.py',
           #                  'ckpt': 'final_weights/PatchTST/quantile/ETTh1_100_96_720/a2a39ac1680165e5ffbda2c7bbda5add/PatchTST_best_val_QL.pt'
           #                 }
          }

random_states = range(5)

configs = {rs:_configs for rs in random_states}

configs = load_runner(configs, random_states=random_states)
configs = get_predictions(configs)
# metrics_results = self.compute_evaluation_metrics(returns_all)

2025-06-08 13:08:17,276 - easytorch-env - INFO - Enable TF32 mode
2025-06-08 13:08:17,284 - easytorch-env - INFO - Use deterministic algorithms.
2025-06-08 13:08:17,284 - easytorch-env - INFO - Set cudnn deterministic.
2025-06-08 13:08:17,285 - easytorch - INFO - Set ckpt save dir: '/home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/590c87b95ca683a9092af664b104c9bd'
2025-06-08 13:08:17,286 - easytorch - INFO - Building model.


Using random state 0


2025-06-08 13:08:19,032 - easytorch - INFO - Load model from : /home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/a8de06edad7530010e0b704422b431a2/PatchTST_best_val_NLL.pt
2025-06-08 13:08:19,036 - easytorch - INFO - Loading Checkpoint from '/home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/a8de06edad7530010e0b704422b431a2/PatchTST_best_val_NLL.pt'


PatchTST
Loading model checkpoint from /home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/a8de06edad7530010e0b704422b431a2/PatchTST_best_val_NLL.pt


2025-06-08 13:08:19,485 - easytorch-env - INFO - Enable TF32 mode
2025-06-08 13:08:19,494 - easytorch-env - INFO - Use deterministic algorithms.
2025-06-08 13:08:19,495 - easytorch-env - INFO - Set cudnn deterministic.
2025-06-08 13:08:19,496 - easytorch - INFO - Set ckpt save dir: '/home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/590c87b95ca683a9092af664b104c9bd'
2025-06-08 13:08:19,497 - easytorch - INFO - Building model.


Using random state 1


2025-06-08 13:08:20,641 - easytorch - INFO - Load model from : /home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/a8de06edad7530010e0b704422b431a2/PatchTST_best_val_NLL.pt
2025-06-08 13:08:20,644 - easytorch - INFO - Loading Checkpoint from '/home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/a8de06edad7530010e0b704422b431a2/PatchTST_best_val_NLL.pt'


PatchTST
Loading model checkpoint from /home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/a8de06edad7530010e0b704422b431a2/PatchTST_best_val_NLL.pt


2025-06-08 13:08:21,203 - easytorch-env - INFO - Enable TF32 mode
2025-06-08 13:08:21,210 - easytorch-env - INFO - Use deterministic algorithms.
2025-06-08 13:08:21,211 - easytorch-env - INFO - Set cudnn deterministic.
2025-06-08 13:08:21,211 - easytorch - INFO - Set ckpt save dir: '/home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/590c87b95ca683a9092af664b104c9bd'
2025-06-08 13:08:21,212 - easytorch - INFO - Building model.


Using random state 2


2025-06-08 13:08:22,371 - easytorch - INFO - Load model from : /home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/a8de06edad7530010e0b704422b431a2/PatchTST_best_val_NLL.pt
2025-06-08 13:08:22,374 - easytorch - INFO - Loading Checkpoint from '/home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/a8de06edad7530010e0b704422b431a2/PatchTST_best_val_NLL.pt'


PatchTST
Loading model checkpoint from /home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/a8de06edad7530010e0b704422b431a2/PatchTST_best_val_NLL.pt


2025-06-08 13:08:26,167 - easytorch-env - INFO - Enable TF32 mode
2025-06-08 13:08:26,174 - easytorch-env - INFO - Use deterministic algorithms.
2025-06-08 13:08:26,175 - easytorch-env - INFO - Set cudnn deterministic.
2025-06-08 13:08:26,176 - easytorch - INFO - Set ckpt save dir: '/home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/590c87b95ca683a9092af664b104c9bd'
2025-06-08 13:08:26,176 - easytorch - INFO - Building model.


Using random state 3


2025-06-08 13:08:27,316 - easytorch - INFO - Load model from : /home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/a8de06edad7530010e0b704422b431a2/PatchTST_best_val_NLL.pt
2025-06-08 13:08:27,319 - easytorch - INFO - Loading Checkpoint from '/home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/a8de06edad7530010e0b704422b431a2/PatchTST_best_val_NLL.pt'


PatchTST
Loading model checkpoint from /home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/a8de06edad7530010e0b704422b431a2/PatchTST_best_val_NLL.pt


2025-06-08 13:08:27,838 - easytorch-env - INFO - Enable TF32 mode
2025-06-08 13:08:27,845 - easytorch-env - INFO - Use deterministic algorithms.
2025-06-08 13:08:27,847 - easytorch-env - INFO - Set cudnn deterministic.
2025-06-08 13:08:27,847 - easytorch - INFO - Set ckpt save dir: '/home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/590c87b95ca683a9092af664b104c9bd'
2025-06-08 13:08:27,848 - easytorch - INFO - Building model.


Using random state 4


2025-06-08 13:08:28,976 - easytorch - INFO - Load model from : /home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/a8de06edad7530010e0b704422b431a2/PatchTST_best_val_NLL.pt
2025-06-08 13:08:28,979 - easytorch - INFO - Loading Checkpoint from '/home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/a8de06edad7530010e0b704422b431a2/PatchTST_best_val_NLL.pt'


PatchTST
Loading model checkpoint from /home/kreffert/Probabilistic_LTSF/BasicTS/final_weights/PatchTST/univariate/ETTh1_100_96_720/a8de06edad7530010e0b704422b431a2/PatchTST_best_val_NLL.pt


2025-06-08 13:08:29,542 - easytorch - INFO - Test dataset length: 2065
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 33/33 [00:08<00:00,  3.83it/s]
2025-06-08 13:08:38,176 - easytorch - INFO - Test dataset length: 2065
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 33/33 [00:07<00:00,  4.50it/s]
2025-06-08 13:08:45,540 - easytorch - INFO - Test dataset length: 2065
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 33/33 [00:07<00:00,  4.51it/s]
2025-06-08 13:08:52,863 - easytorch - INFO - Test dataset length: 2065
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 33/33 [00:07<00:00,  4.49it/s]
2025-06-08 13:09:00,244 - easytorch - INFO - Test dataset length: 2065
100%|████████████████████████████████████████████

In [3]:
import torch
def vs_ensemble_torch(obs, fct, p=1.0):
    """
    Compute Variogram Score using PyTorch on GPU.
    obs: shape (..., D)
    fct: shape (..., M, D)
    """
    M = fct.shape[-2]

    # Compute ensemble variogram component
    fct_diff = fct.unsqueeze(-2) - fct.unsqueeze(-1)  # (B, M, D, D)
    # print(fct_diff.shape)
    vfct = (fct_diff.abs() ** p).sum(dim=-3) / M  # (B, D, D)
    # print(vfct.shape)
    # Compute observed variogram component
    obs_diff = obs.unsqueeze(-2) - obs.unsqueeze(-1)  # (B, D, D)
    vobs = (obs_diff.abs() ** p)  # (B, D, D)
    # print(vobs.shape)
    # print(vfct.shape)
    vs = ((vfct - vobs) ** 2).sum(dim=(-2, -1))  # (B,)
    return vs

def es_ensemble_torch(obs: torch.Tensor, fct: torch.Tensor) -> torch.Tensor:
    """
    Compute the energy score using PyTorch.
    
    Parameters:
    - obs: Tensor of shape (B, D)
    - fct: Tensor of shape (B, M, D)

    Returns:
    - Tensor of shape (B,) with energy scores
    """
    M = fct.shape[-2]

    # E_1: mean norm between forecast samples and the observation
    err_norm = torch.norm(fct - obs.unsqueeze(-2), dim=-1)  # (B, M)
    E_1 = err_norm.sum(dim=-1) / M  # (B,)

    # E_2: mean pairwise distance between forecast samples
    spread = fct.unsqueeze(-3) - fct.unsqueeze(-2)  # (B, M, M, D)
    spread_norm = torch.norm(spread, dim=-1)  # (B, M, M)
    E_2 = spread_norm.sum(dim=(-2, -1)) / (M**2) # (B,)

    return E_1 - 0.5 * E_2  # (B,)

def sample(runner, returns_all, random_state=None):
    from prob.prob_head import ProbabilisticHead # load that class for sampling
    head = ProbabilisticHead(1, 1, runner.distribution_type, prob_args=runner.prob_args)
    samples = []
    batch_size = 64
    num_batches = int(returns_all['prediction'].shape[0]/batch_size)+1
    for b in range(num_batches):
        start, end = b*batch_size, min((b+1)*batch_size, returns_all['prediction'].shape[0])
        pred = returns_all['prediction'][start:end, :, :, :]
        sample = head.sample(pred, num_samples=100, random_state=random_state) # [samples x bs x seq_len x nvars]
        sample = sample.permute(1, 0, 2, 3)       # [bs x samples x seq_len x nvars]
        samples.append(sample)
    samples = torch.cat(samples, dim=0)
    return samples

def evaluate(predictions, returns_all, batch_size=4):
    import scoringrules as sr
    import numpy as np
    device = returns_all['target'].device
    targets = returns_all['target'].squeeze(-1)#.detach().cpu()
    sampless = predictions.permute(0, 2, 3, 1)#.detach().cpu() 
    num_batches = int(returns_all['prediction'].shape[0]/batch_size)+1
    # Lists to accumulate metric values
    crps_list = []
    crps_sum_list = []
    vs_05_list = []
    vs_1_list = []
    vs_2_list = []
    es_list = []
    # Loop through batches
    pbar = tqdm(range(num_batches))
    for b in pbar:
        start, end = b * batch_size, min((b + 1) * batch_size, returns_all['prediction'].shape[0])
        if start == end:
            print("SKipping")
            continue  # Skip empty batch
    
        samples = sampless[start:end, :, :, :]
        target = targets[start:end, :, :]
    
        crps = np.mean(sr.crps_ensemble(target.detach().cpu(), samples.detach().cpu(), estimator='pwm'))
        crps_sum = np.mean(sr.crps_ensemble(target.detach().cpu().sum(axis=-1), samples.detach().cpu().sum(axis=-2), estimator='pwm'))
        # vs_05 = np.mean(sr.variogram_score(target.permute(0, 2, 1), samples.permute(0, 2, 3, 1), p=0.5, backend='numba'))
        vs_05 = torch.mean(vs_ensemble_torch(target.permute(0, 2, 1).to(device), samples.permute(0, 2, 3, 1).to(device), p=0.5))
        # vs_1 = np.mean(sr.variogram_score(target.permute(0, 2, 1), samples.permute(0, 2, 3, 1), p=1.0, backend='numba'))
        vs_1 = torch.mean(vs_ensemble_torch(target.permute(0, 2, 1).to(device), samples.permute(0, 2, 3, 1).to(device), p=1))
        # vs_2 = np.mean(sr.variogram_score(target.permute(0, 2, 1), samples.permute(0, 2, 3, 1), p=2.0, backend='numba'))
        vs_2 = torch.mean(vs_ensemble_torch(target.permute(0, 2, 1).to(device), samples.permute(0, 2, 3, 1).to(device), p=2))
        
        # es = np.mean(sr.energy_score(target.permute(0, 2, 1), samples.permute(0, 2, 3, 1), backend='numba'))
        es = torch.mean(es_ensemble_torch(target.permute(0, 2, 1).to(device), samples.permute(0, 2, 3, 1).to(device)))
        
        # Append to lists
        crps_list.append(crps)
        crps_sum_list.append(crps_sum)
        vs_05_list.append(vs_05.detach().cpu())
        vs_1_list.append(vs_1.detach().cpu())
        vs_2_list.append(vs_2.detach().cpu())
        es_list.append(es.detach().cpu())
    
    # Final averages
    final_scores = {
        "CRPS": np.mean(crps_list),
        "CRPS_Sum": np.mean(crps_sum_list),
        "VS_0.5": np.mean(vs_05_list),
        "VS_1.0": np.mean(vs_1_list),
        "VS_2.0": np.mean(vs_2_list),
        "ES": np.mean(es_list),
    }
    return final_scores

In [4]:
def evaluate_all(configs):
    eval_dict = {rs:{} for rs in configs.keys()}
    for rs in configs.keys():
        for key in configs[rs].keys():
            samples = sample(configs[rs][key]['runner'], configs[rs][key]['returns_all'], random_state=rs)
            eval_dict[rs][key] = evaluate(samples, configs[rs][key]['returns_all'], batch_size=4)
    print(eval_dict)

In [5]:
evaluate_all(configs)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 517/517 [04:07<00:00,  2.09it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 517/517 [04:13<00:00,  2.04it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 517/517 [04:06<00:00,  2.09it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 517/517 [04:06<00:00,  2.10it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 517/517 [04:06<00:00,  2.10it/s]


{0: {'ETTh1_PTST_u': {'CRPS': np.float64(2.4197740580153275), 'CRPS_Sum': np.float64(8.704141144967366), 'VS_0.5': np.float32(572256.2), 'VS_1.0': np.float32(14644100.0), 'VS_2.0': np.float32(1985060100000.0), 'ES': np.float32(83.836754)}}, 1: {'ETTh1_PTST_u': {'CRPS': np.float64(2.4191808755737285), 'CRPS_Sum': np.float64(8.702462262138733), 'VS_0.5': np.float32(572201.0), 'VS_1.0': np.float32(14609265.0), 'VS_2.0': np.float32(28738159000.0), 'ES': np.float32(83.84088)}}, 2: {'ETTh1_PTST_u': {'CRPS': np.float64(2.4195399059040246), 'CRPS_Sum': np.float64(8.704312420450742), 'VS_0.5': np.float32(572196.56), 'VS_1.0': np.float32(14605370.0), 'VS_2.0': np.float32(48451195000.0), 'ES': np.float32(83.866234)}}, 3: {'ETTh1_PTST_u': {'CRPS': np.float64(2.41907135491505), 'CRPS_Sum': np.float64(8.702063031509304), 'VS_0.5': np.float32(572468.4), 'VS_1.0': np.float32(14621595.0), 'VS_2.0': np.float32(48899846000.0), 'ES': np.float32(83.85971)}}, 4: {'ETTh1_PTST_u': {'CRPS': np.float64(2.419596

In [10]:
results = {0: {'ETTh1_PTST_u': {'CRPS': np.float64(2.4197740580153275), 'CRPS_Sum': np.float64(8.704141144967366), 'VS_0.5': np.float32(572256.2), 'VS_1.0': np.float32(14644100.0), 'VS_2.0': np.float32(1985060100000.0), 'ES': np.float32(83.836754)}}, 
           1: {'ETTh1_PTST_u': {'CRPS': np.float64(2.4191808755737285), 'CRPS_Sum': np.float64(8.702462262138733), 'VS_0.5': np.float32(572201.0), 'VS_1.0': np.float32(14609265.0), 'VS_2.0': np.float32(28738159000.0), 'ES': np.float32(83.84088)}}, 
           2: {'ETTh1_PTST_u': {'CRPS': np.float64(2.4195399059040246), 'CRPS_Sum': np.float64(8.704312420450742), 'VS_0.5': np.float32(572196.56), 'VS_1.0': np.float32(14605370.0), 'VS_2.0': np.float32(48451195000.0), 'ES': np.float32(83.866234)}}, 
           3: {'ETTh1_PTST_u': {'CRPS': np.float64(2.41907135491505), 'CRPS_Sum': np.float64(8.702063031509304), 'VS_0.5': np.float32(572468.4), 'VS_1.0': np.float32(14621595.0), 'VS_2.0': np.float32(48899846000.0), 'ES': np.float32(83.85971)}}, 
           4: {'ETTh1_PTST_u': {'CRPS': np.float64(2.4195961071133296), 'CRPS_Sum': np.float64(8.704253091582524), 'VS_0.5': np.float32(572262.1), 'VS_1.0': np.float32(14613185.0), 'VS_2.0': np.float32(138508400000.0), 'ES': np.float32(83.84413)}}}
# Extract metrics
metrics = list(next(iter(results.values()))['ETTh1_PTST_u'].keys())
agg = {metric: [] for metric in metrics}

# rescaling = {
#     "VS_0.5": 1e-4,
#     "VS_1.0": 1e-6,
#     "VS_2.0": 1e-10,
# }
rescaling = {}
for run in results.values():
    for metric in metrics:
        if metric in rescaling.keys():
            agg[metric].append(run['ETTh1_PTST_u'][metric]*rescaling[metric])
        else:
            agg[metric].append(run['ETTh1_PTST_u'][metric])

# Compute stats
summary = {}
for metric in metrics:
    values = np.array(agg[metric], dtype=np.float64)
    summary[metric] = {
        "mean": np.mean(values),
        "std": np.std(values)
    }

# Display results
for metric, stats in summary.items():
    print(f"{metric}: {stats['mean']:.4f} ± {stats['std']:.4f}")

CRPS: 2.4194 ± 0.0003
CRPS_Sum: 8.7034 ± 0.0010
VS_0.5: 572276.8500 ± 99.5238
VS_1.0: 14618703.0000 ± 13788.8675
VS_2.0: 449931549081.6000 ± 768507768348.2292
ES: 83.8495 ± 0.0114


In [6]:
import scoringrules as sr
import numpy as np
device = returns_all['target'].device
targets = returns_all['target'].squeeze(-1)#.detach().cpu()
sampless = prediction.permute(0, 2, 3, 1)#.detach().cpu() 
print(sampless.shape)
print(targets.shape)
 # 3. Compute approximate metrics
batch_size = 4
num_batches = int(returns_all['prediction'].shape[0]/batch_size)+1
# Lists to accumulate metric values
crps_list = []
crps_sum_list = []
vs_05_list = []
vs_1_list = []
vs_2_list = []
es_list = []

import torch

def vs_ensemble_torch(obs, fct, p=1.0):
    """
    Compute Variogram Score using PyTorch on GPU.
    obs: shape (..., D)
    fct: shape (..., M, D)
    """
    M = fct.shape[-2]

    # Compute ensemble variogram component
    fct_diff = fct.unsqueeze(-2) - fct.unsqueeze(-1)  # (B, M, D, D)
    # print(fct_diff.shape)
    vfct = (fct_diff.abs() ** p).sum(dim=-3) / M  # (B, D, D)
    # print(vfct.shape)
    # Compute observed variogram component
    obs_diff = obs.unsqueeze(-2) - obs.unsqueeze(-1)  # (B, D, D)
    vobs = (obs_diff.abs() ** p)  # (B, D, D)
    # print(vobs.shape)
    # print(vfct.shape)
    vs = ((vfct - vobs) ** 2).sum(dim=(-2, -1))  # (B,)
    return vs

def es_ensemble_torch(obs: torch.Tensor, fct: torch.Tensor) -> torch.Tensor:
    """
    Compute the energy score using PyTorch.
    
    Parameters:
    - obs: Tensor of shape (B, D)
    - fct: Tensor of shape (B, M, D)

    Returns:
    - Tensor of shape (B,) with energy scores
    """
    M = fct.shape[-2]

    # E_1: mean norm between forecast samples and the observation
    err_norm = torch.norm(fct - obs.unsqueeze(-2), dim=-1)  # (B, M)
    E_1 = err_norm.sum(dim=-1) / M  # (B,)

    # E_2: mean pairwise distance between forecast samples
    spread = fct.unsqueeze(-3) - fct.unsqueeze(-2)  # (B, M, M, D)
    spread_norm = torch.norm(spread, dim=-1)  # (B, M, M)
    E_2 = spread_norm.sum(dim=(-2, -1)) / (M**2) # (B,)

    return E_1 - 0.5 * E_2  # (B,)

# Loop through batches
pbar = tqdm(range(num_batches))
for b in pbar:
    start, end = b * batch_size, min((b + 1) * batch_size, returns_all['prediction'].shape[0])
    if start == end:
        print("SKipping")
        continue  # Skip empty batch

    samples = sampless[start:end, :, :, :]
    target = targets[start:end, :, :]

    crps = np.mean(sr.crps_ensemble(target.detach().cpu(), samples.detach().cpu(), estimator='pwm'))
    crps_sum = np.mean(sr.crps_ensemble(target.detach().cpu().sum(axis=-1), samples.detach().cpu().sum(axis=-2), estimator='pwm'))
    # vs_05 = np.mean(sr.variogram_score(target.permute(0, 2, 1), samples.permute(0, 2, 3, 1), p=0.5, backend='numba'))
    vs_05 = torch.mean(vs_ensemble_torch(target.permute(0, 2, 1).to(device), samples.permute(0, 2, 3, 1).to(device), p=0.5))
    # vs_1 = np.mean(sr.variogram_score(target.permute(0, 2, 1), samples.permute(0, 2, 3, 1), p=1.0, backend='numba'))
    vs_1 = torch.mean(vs_ensemble_torch(target.permute(0, 2, 1).to(device), samples.permute(0, 2, 3, 1).to(device), p=1))
    # vs_2 = np.mean(sr.variogram_score(target.permute(0, 2, 1), samples.permute(0, 2, 3, 1), p=2.0, backend='numba'))
    vs_2 = torch.mean(vs_ensemble_torch(target.permute(0, 2, 1).to(device), samples.permute(0, 2, 3, 1).to(device), p=2))
    
    # es = np.mean(sr.energy_score(target.permute(0, 2, 1), samples.permute(0, 2, 3, 1), backend='numba'))
    es = torch.mean(es_ensemble_torch(target.permute(0, 2, 1).to(device), samples.permute(0, 2, 3, 1).to(device)))
    
    # Append to lists
    crps_list.append(crps)
    crps_sum_list.append(crps_sum)
    vs_05_list.append(vs_05.detach().cpu())
    vs_1_list.append(vs_1.detach().cpu())
    vs_2_list.append(vs_2.detach().cpu())
    es_list.append(es.detach().cpu())

    # # Update tqdm with running averages
    # # pbar.set_description(f"CRPS: {np.mean(crps_list):.4f}, VS1: {np.mean(vs_1_list):.4f}, ES: {np.mean(es_list):.4f}")
    
    # pbar.set_description(f"VS: {vs_05:.4f}, VS_T: {vs_05_torch:.4f},")

# Final averages
final_scores = {
    "CRPS": np.mean(crps_list),
    "CRPS_Sum": np.mean(crps_sum_list),
    "VS_0.5": np.mean(vs_05_list),
    "VS_1.0": np.mean(vs_1_list),
    "VS_2.0": np.mean(vs_2_list),
    "ES": np.mean(es_list),
}

print("\nFinal Scores:")
for k, v in final_scores.items():
    print(f"{k}: {v:.4f}")

NameError: name 'returns_all' is not defined

In [None]:
Final Scores:
CRPS: 2.4193
CRPS_Sum: 8.7032
VS_0.5: 572217.8750
VS_1.0: 14610408.0000
VS_2.0: 345045532672.0000
ES: 83.8458


Final Scores:
CRPS: 2.4194
CRPS_Sum: 8.7040
VS_0.5: 572181.3125
VS_1.0: 14610432.0000
VS_2.0: 168409481216.0000
ES: 83.8460