In [85]:
from probts.data.data_utils.data_scaler import Scaler, StandardScaler

from probts.model.forecaster import LinearForecaster, NaiveForecaster
from probts.model.forecast_module import ProbTSForecastModule
from probts.data import ProbTSDataModule, DataManager, ProbTSBatchData
from probts.utils import find_best_epoch
from lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
import torch
import matplotlib.pyplot as plt

In [451]:
class BinaryQuantizer(Scaler):
    def __init__(self, num_bins=1000, min_val=-5.0, max_val=5.0):
        self.num_bins = num_bins
        self.min_val = min_val
        self.max_val = max_val
        self.bin_values_ = torch.linspace(self.min_val, self.max_val, self.num_bins)

    def fit(self, values):
        self.min_val = values.min()
        self.max_val = values.max()
        self.bin_values_ = torch.linspace(self.min_val, self.max_val, self.num_bins)

    def fit_transform(self, values):
        self.fit(values)
        return self.transform(values)

    def transform(self, values):
        bin_thresholds = self.bin_values_.reshape(1, 1, -1)
        return (values >= bin_thresholds).float()

    def inverse_transform(self, values):
        reversed_bin = torch.flip(values, dims=(-1,))
        idx_first_one_reversed = reversed_bin.argmax(axis=-1)[..., None]
        idx_last_one = self.num_bins - 1 - idx_first_one_reversed
        reconstructed = self.bin_values_[idx_last_one]
        return reconstructed

In [487]:
class StandardBinScaler(Scaler):
    def __init__(self, standard: StandardScaler, bin: BinaryQuantizer):
        self.standard = standard
        self.bin = bin

    def fit(self, X):
        Z = self.standard.fit_transform(X)
        self.bin.fit(Z)

    def transform(self, X):
        Z = self.standard.transform(X)
        print(Z)
        return self.bin.transform(Z)

    def fit_transform(self, X):
        self.fit(X)
        return self.transform(X)

    def inverse_transform(self, X):
        Z = self.bin.inverse_transform(X)
        return self.standard.inverse_transform(Z)

In [488]:
# data_manager = DataManager(
#     dataset='tourism_monthly',
#     path='../datasets',
#     context_length=12,
#     prediction_length=12,
# )
# data_manager.context_length

In [489]:
class CustomDataManager(DataManager):
    def _configure_scaler(self, scaler_type: str):
        """Configure the scaler."""
        if scaler_type == "standard":
            return StandardScaler(var_specific=self.var_specific_norm)
        elif scaler_type == "temporal":
            return TemporalScaler()
        elif scaler_type == "binary":
            return BinaryQuantizer()
        elif scaler_type == "standard_binary":
            return StandardBinScaler(StandardScaler(var_specific=self.var_specific_norm), BinaryQuantizer())
        return IdentityScaler()

In [490]:
# data_module = ProbTSDataModule(
#     data_manager=data_manager,
#     batch_size=32,
#     test_batch_size=32,
#     num_workers=8,
# )
# test_dataloader = data_module.test_dataloader()
# train_dataloader = data_module.train_dataloader()
# val_dataloader = data_module.val_dataloader()

In [491]:
# for test_batch in test_dataloader:
#     break

In [492]:
# batch_data = ProbTSBatchData(test_batch, 'cpu')
# batch_data.past_target_cdf.shape

In [493]:
# plt.figure(figsize=(10,3))
# plt.plot(batch_data.past_target_cdf[13, :, 0].t())
# plt.show()

In [494]:
# scaler = StandardBinScaler(StandardScaler(), BinaryQuantizer())
# scaler.fit(batch_data.past_target_cdf)
# transformed = scaler.transform(batch_data.past_target_cdf)
# transformed.shape

In [495]:
# plt.figure(figsize=(10,3))
# plt.imshow(transformed[13].T, aspect='auto', interpolation='none', cmap='Reds')
# plt.show()

In [496]:
# reconstructed = scaler.inverse_transform(transformed)
# reconstructed.shape

In [497]:
# plt.figure(figsize=(10,3))
# plt.plot(reconstructed[13, :, 0].t())
# plt.show()


In [498]:
data_manager = CustomDataManager(
    dataset='tourism_monthly',
    path='../datasets',
    context_length=72,
    prediction_length=24,
    scaler="standard_binary",
)

# data_manager = DataManager(
#     dataset='m4_daily',
#     # dataset='etth1',
#     path='./datasets',
#     context_length=12,
#     prediction_length=12,
#     scaler="standard_binary",
# )

Loading Short-term Dataset: tourism_monthly


Download tourism_monthly_dataset.zip:: 200kB [00:00, 369kB/s]
creating json files: 100%|██████████| 366/366 [00:00<00:00, 443803.20it/s]

No validation set is used.





In [499]:
data_manager.context_length

72

In [500]:
data_module = ProbTSDataModule(
    data_manager=data_manager,
    batch_size=1,
    test_batch_size=1,
    num_workers=8,
)
test_dataloader = data_module.test_dataloader()
train_dataloader = data_module.train_dataloader()
val_dataloader = data_module.val_dataloader()

In [501]:
for test_batch in test_dataloader:
    break

In [502]:
test_batch['past_target_cdf'].shape

torch.Size([1, 84])

In [503]:
test_batch['past_target_cdf'].reshape(-1,1).shape

torch.Size([84, 1])

In [506]:
data_manager.scaler.standard.mean

In [504]:
data_manager.scaler.transform(test_batch['past_target_cdf'].reshape(-1,1))

tensor([[5466.7803],
        [3235.1677],
        [2157.9800],
        [1379.7252],
        [1728.0400],
        [1350.1099],
        [1216.0149],
        [1751.3252],
        [1805.3201],
        [2570.0249],
        [3204.2402],
        [5395.7202],
        [6078.8286],
        [3587.0984],
        [2285.1951],
        [1582.1899],
        [1787.4298],
        [1554.8701],
        [1409.8649],
        [1612.1250],
        [2286.2400],
        [2913.7551],
        [3645.9084],
        [5956.7085],
        [6326.9751],
        [3914.6602],
        [2617.6750],
        [1675.1650],
        [2139.2200],
        [1715.4899],
        [1663.5800],
        [2053.7000],
        [2354.9299],
        [3038.5918],
        [3470.6094],
        [6606.1836],
        [6587.6367],
        [4133.7827],
        [2960.0244],
        [1762.5850],
        [2125.6401],
        [1815.9150],
        [1632.3149],
        [2210.3950],
        [2210.2151],
        [3099.2693],
        [3468.7778],
        [6482

tensor([[[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]]])

In [283]:
# for train_batch in train_dataloader:
#     break

In [284]:
# batch_data = ProbTSBatchData(test_batch, 'cpu')
# batch_data.past_target_cdf.shape

In [285]:
data_manager.context_length

72

In [507]:
def sliding_window_batch(x, L, H):
    """
    x: Tensor of shape (B, L+H, C)
    Returns: Tensor of shape (B, H, L, C)
    """
    B, total_len, C = x.shape
    assert total_len >= L + H, "Not enough sequence length for given L and H"

    windows = [x[:, h:h + L, :].unsqueeze(1) for h in range(H)]  # list of (B, 1, L, C)
    return torch.cat(windows, dim=1)  # (B, H, L, C)

In [508]:
from probts.model.forecaster import Forecaster
from torch import nn
import torch.nn.functional as F


class BinConv(Forecaster):
    def __init__(self, context_length: int, num_bins: int, kernel_size_across_bins_2d: int = 3,
                 kernel_size_across_bins_1d: int = 3, num_filters_2d: int = 8,
                 num_filters_1d: int = 32, is_cum_sum: bool = False, **kwargs) -> None:
        """
        Initialize the model with parameters.
        """
        super().__init__(context_length=context_length, **kwargs)
        # Initialize model parameters here
        self.context_length = context_length
        self.num_bins = num_bins
        self.num_filters_2d = num_filters_2d
        self.num_filters_1d = num_filters_1d
        self.kernel_size_across_bins_2d = kernel_size_across_bins_2d
        self.kernel_size_across_bins_1d = kernel_size_across_bins_1d
        self.is_cum_sum = is_cum_sum

        # Conv2d over (context_length, num_bins)
        self.conv = nn.Conv2d(
            in_channels=1,
            out_channels=self.num_filters_2d,
            kernel_size=(context_length, kernel_size_across_bins_2d),
            bias=True
        )

        self.conv1d_1 = nn.Conv1d(
            in_channels=self.num_filters_2d,
            out_channels=self.num_filters_1d,
            kernel_size=kernel_size_across_bins_1d,
            bias=True
        )

        self.conv1d_2 = nn.Conv1d(
            in_channels=self.num_filters_1d,
            out_channels=self.num_bins,
            kernel_size=kernel_size_across_bins_1d,
            bias=True
        )

    # def forward(self, inputs):
    #     """
    #     Forward pass for the model.
    #
    #     Parameters:
    #     inputs [Tensor]: Input tensor for the model.
    #
    #     Returns:
    #     Tensor: Output tensor.
    #     """
    #     # Perform the forward pass of the model
    #     return outputs

    def forward(self, x):
        def pad_channels(tensor, pad_size: int, pad_val_left=1.0, pad_val_right=0.0):
            if pad_size == 0:
                return tensor
            left = torch.full((*tensor.shape[:-1], pad_size), pad_val_left, device=tensor.device)
            right = torch.full((*tensor.shape[:-1], pad_size), pad_val_right, device=tensor.device)
            return torch.cat([left, tensor, right], dim=-1)

        x = x.float()
        # x: (batch_size, context_length, num_bins)
        batch_size, context_length, num_bins = x.shape
        assert context_length == self.context_length, "Mismatch in context length"

        pad2d = self.kernel_size_across_bins_2d // 2 if self.kernel_size_across_bins_2d > 1 else 0
        x_padded = pad_channels(x, pad2d)
        x_conv_in = x_padded.unsqueeze(1)
        conv_out = F.relu(self.conv(x_conv_in).squeeze(2))  # (batch_size, num_filters_2d, num_bins)

        pad1d = self.kernel_size_across_bins_1d // 2 if self.kernel_size_across_bins_1d > 1 else 0
        h_padded = pad_channels(conv_out, pad1d)
        h = F.relu(self.conv1d_1(h_padded))

        h_padded = pad_channels(h, pad1d)
        out = self.conv1d_2(h_padded).mean(dim=1)  # (batch_size, num_bins)

        if self.is_cum_sum:
            out = torch.flip(torch.cumsum(torch.flip(out, dims=[1]), dim=1), dims=[1])
        return out

    def loss(self, batch_data):
        """
        Compute the loss for the given batch data.

        Parameters:
        batch_data [dict]: Dictionary containing input data and possibly target data.

        Returns:
        Tensor: Computed loss.
        """
        # Extract inputs and targets from batch_data

        inputs = self.get_inputs(batch_data, 'all')
        # print(f'bool:{torch.allclose(inputs[:, -self.prediction_length:, :], batch_data.future_target_cdf.float())}')
        inputs = sliding_window_batch(inputs, self.context_length, self.prediction_length).float()
        outputs = self(inputs.view(-1, *inputs.shape[2:]))
        # outputs = outputs[:, -self.prediction_length-1:-1, ...]
        target = batch_data.future_target_cdf.float()
        loss = F.binary_cross_entropy_with_logits(input=outputs, target=target.view(-1, *target.shape[2:]),)
        print(loss)
        return loss

    def forecast(self, batch_data, num_samples=None):
        inputs = self.get_inputs(batch_data, 'encode')
        current_context = inputs.clone()
        forecasts = []
        for _ in range(self.prediction_length):
            pred = F.sigmoid(self(current_context))  # (B, D)
            pred = (pred >= 0.5).int()
            forecasts.append(pred.unsqueeze(1))  # (B, 1, D)
            next_input = pred.unsqueeze(1)
            current_context = torch.cat([current_context[:, 1:], next_input], dim=1)

        return torch.cat(forecasts, dim=1)  # (B, T, D)

    # def forecast(self, batch_data, num_samples=None):
    #     """
    #     Generate forecasts for the given batch data.
    #
    #     Parameters:
    #     batch_data [dict]: Dictionary containing input data.
    #     num_samples [int, optional]: Number of samples per distribution during evaluation. Defaults to None.
    #
    #     Returns:
    #     Tensor: Forecasted outputs.
    #     """
    #     # Perform the forward pass to get the outputs
    #     outputs = self(batch_data.past_target_cdf[:, -self.context_length:, :])
    #     print(f'num samples:{num_samples}')
    #     if num_samples is not None:
    #         # If num_samples is specified, use it to sample from the distribution
    #         outputs = self.sample_from_distribution(outputs, num_samples)
    #     else:
    #         # If perform point estimation, the num_samples is equal to 1
    #         outputs = outputs.unsqueeze(1)
    #     return outputs  # [batch_size, num_samples, prediction_length, var_num]

In [509]:
data_manager.context_length

72

In [510]:
forecaster = BinConv(
    num_bins=1000,
    kernel_size_across_bins_2d=1,
    kernel_size_across_bins_1d=5,
    num_filters_2d=8,
    num_filters_1d=32,
    individual=True,
    use_lags=False,
    use_feat_idx_emb=False,
    use_time_feat=False,
    target_dim=data_manager.target_dim,
    context_length=data_manager.context_length,
    prediction_length=data_manager.prediction_length,
    freq=data_manager.freq,
    lags_list=data_manager.lags_list,
    time_feat_dim=data_manager.time_feat_dim,
    dataset=data_manager.dataset,
)
model = ProbTSForecastModule(
    forecaster=forecaster,
    scaler=data_manager.scaler,
    learning_rate=0.005,
    quantiles_num=20,
    num_samples=None
)

sampling_weight_scheme: none


/Users/andreichernov/miniforge3/envs/probts/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'forecaster' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['forecaster'])`.


In [511]:
trainer = Trainer(
    accelerator="cpu",
    devices=1,
    strategy="auto",
    max_epochs=1,
    use_distributed_sampler=False,
    limit_train_batches=100,
    log_every_n_steps=1,
    accumulate_grad_batches=8,
    default_root_dir='./results',
    logger=CSVLogger('./logs'),
)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/andreichernov/miniforge3/envs/probts/lib/python3.10/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


In [512]:
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)


  | Name       | Type    | Params | Mode 
-----------------------------------------------
0 | forecaster | BinConv | 162 K  | train
-----------------------------------------------
162 K     Trainable params
0         Non-trainable params
162 K     Total params
0.652     Total estimated model params size (MB)
4         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/andreichernov/miniforge3/envs/probts/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Users/andreichernov/miniforge3/envs/probts/lib/python3.10/site-packages/lightning/pytorch/utilities/data.py:106: Total length of `DataLoader` across ranks is zero. Please make sure this was your intention.


Training: |          | 0/? [00:00<?, ?it/s]

tensor([[[ 71498.9531],
         [ 79187.1094],
         [101896.1016],
         [115971.7969],
         [ 94962.1484],
         [ 80648.3281],
         [ 64196.0781],
         [ 50364.8594],
         [ 57624.0586],
         [ 47163.8711],
         [ 48874.0703],
         [ 62737.6094],
         [ 69621.1328],
         [ 71454.2188],
         [107916.7969],
         [120461.5000],
         [ 99441.1797],
         [ 84936.5391],
         [ 62809.5195],
         [ 54028.4805],
         [ 58605.9102],
         [ 50516.3398],
         [ 55711.5391],
         [ 55798.4102],
         [ 65033.1797],
         [ 89421.1406],
         [119027.8984],
         [133411.2969],
         [112890.7031],
         [ 96718.1406],
         [ 76462.7969],
         [ 57951.6797],
         [ 62094.6914],
         [ 55118.2305],
         [ 66128.3516],
         [ 71334.2578],
         [ 75644.7188],
         [ 98380.4297],
         [127255.0000],
         [146442.7031],
         [121934.7969],
         [ 88537

KeyboardInterrupt: 

In [435]:
for test_batch in test_dataloader:
    break

In [436]:
test_batch['past_target_cdf'].shape


torch.Size([1, 84])

In [443]:
batch_data = ProbTSBatchData(test_batch, model.device)
past_target_cdf = model.scaler.transform(batch_data.past_target_cdf)
future_target_cdf = model.scaler.transform(batch_data.future_target_cdf)
batch_data.past_target_cdf = past_target_cdf

batch_idx = 0
with torch.no_grad():
    prediction = model.forecaster.forecast(batch_data)



In [450]:
batch_data

<probts.data.data_wrapper.ProbTSBatchData at 0x2a532a5c0>

In [442]:
batch_data.past_target_cdf.shape

torch.Size([1, 84, 1000])

In [427]:
batch_data.past_target_cdf.shape

torch.Size([1, 84, 1000])

In [None]:
context_length = data_manager.context_length
prediction_length = data_manager.prediction_length
past_range = range(0, context_length)
future_range = range(context_length, context_length + prediction_length)
full_range = range(0, context_length + prediction_length)

for i in range(min(10, forecaster.target_dim)):
    target = torch.cat([past_target_cdf[batch_idx, -context_length:, i], future_target_cdf[batch_idx, :, i]])
    plt.figure(figsize=(10, 2))
    plt.plot(full_range, target)
    plt.plot(future_range, prediction[:, i])