In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from binconvfm.models.binconv import BinConvForecaster
import numpy as np
import matplotlib.pyplot as plt
from tqdm import trange

In [None]:
class Seq2SeqDataset(Dataset):
    def __init__(self, seq, input_len=24, output_len=5):
        self.input_len = input_len
        self.output_len = output_len
        self.seq = torch.tensor(seq, dtype=torch.float32)
        self.seq = self.seq.unsqueeze(-1)
        self.length = len(seq) - input_len - output_len + 1

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        input_seq = self.seq[idx: idx + self.input_len]
        target_seq = self.seq[idx + self.input_len: idx + self.input_len + self.output_len]
        return input_seq, target_seq

In [None]:
x_space = np.linspace(0, 100, 1000)
seq = np.sin(x_space) + np.random.randn(1000) * 0.1

In [None]:
plt.figure(figsize=(10, 3))
plt.plot(seq)
plt.show()

In [None]:
input_len = 100
output_len = 50
batch_size = 256
n_samples = 10
train_ds = Seq2SeqDataset(seq, input_len, 1)  # for binconv, like for decoder-only models output_len = 1
test_ds = Seq2SeqDataset(seq, input_len, output_len)
train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
pred_dataloader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

In [None]:
data = next(iter(train_dataloader))

In [None]:
data[1].shape

In [None]:
model = BinConvForecaster(num_epochs=5, n_samples=1, context_length=input_len, num_filters_2d=input_len,
                          num_filters_1d=input_len, num_bins=256, min_bin_value=-1.5, max_bin_value=1.5, num_blocks=2)
model.fit(train_dataloader, val_dataloader)

In [10]:
model.evaluate(test_dataloader)

torch.Size([256, 256])
probs torch.Size([256, 256])
pred torch.Size([256, 1, 1, 256])
next_input torch.Size([256, 1, 1, 256])
current_context torch.Size([256, 100, 1, 256])
torch.Size([256, 100, 1, 256])
forward single
torch.Size([256, 100, 256])
torch.Size([256, 256])
probs torch.Size([256, 256])
pred torch.Size([256, 1, 1, 256])
next_input torch.Size([256, 1, 1, 256])
current_context torch.Size([256, 100, 1, 256])
torch.Size([256, 100, 1, 256])
forward single
torch.Size([256, 100, 256])
torch.Size([256, 256])
probs torch.Size([256, 256])
pred torch.Size([256, 1, 1, 256])
next_input torch.Size([256, 1, 1, 256])
current_context torch.Size([256, 100, 1, 256])
torch.Size([256, 100, 1, 256])
forward single
torch.Size([256, 100, 256])
torch.Size([256, 256])
probs torch.Size([256, 256])
pred torch.Size([256, 1, 1, 256])
next_input torch.Size([256, 1, 1, 256])
current_context torch.Size([256, 100, 1, 256])
torch.Size([256, 100, 1, 256])
forward single
torch.Size([256, 100, 256])
torch.Size([

{'mase': 2.2905285358428955, 'crps': 0.1375533938407898}

In [None]:
pred = model.predict(pred_dataloader, horizon=output_len)
pred[0].shape

/Users/andreichernov/Documents/Personal/research/foundation TS/binconvfm/venv/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]

current_context torch.Size([256, 100, 1, 256])
torch.Size([256, 100, 1, 256])
forward single
torch.Size([256, 100, 256])
torch.Size([256, 256])
probs torch.Size([256, 256])
pred torch.Size([256, 1, 1, 256])
next_input torch.Size([256, 1, 1, 256])
current_context torch.Size([256, 100, 1, 256])
torch.Size([256, 100, 1, 256])
forward single
torch.Size([256, 100, 256])
torch.Size([256, 256])
probs torch.Size([256, 256])
pred torch.Size([256, 1, 1, 256])
next_input torch.Size([256, 1, 1, 256])
current_context torch.Size([256, 100, 1, 256])
torch.Size([256, 100, 1, 256])
forward single
torch.Size([256, 100, 256])
torch.Size([256, 256])
probs torch.Size([256, 256])
pred torch.Size([256, 1, 1, 256])
next_input torch.Size([256, 1, 1, 256])
current_context torch.Size([256, 100, 1, 256])
torch.Size([256, 100, 1, 256])
forward single
torch.Size([256, 100, 256])
torch.Size([256, 256])
probs torch.Size([256, 256])
pred torch.Size([256, 1, 1, 256])
next_input torch.Size([256, 1, 1, 256])
current_cont

In [None]:
idx0 = torch.randint(len(pred), (1,)).item()
idx1 = torch.randint(len(pred[idx0]), (1,)).item()
input_seq, target_seq = list(pred_dataloader)[idx0]
input_seq, target_seq = input_seq[idx1, :, -1], target_seq[idx1, :, -1]
q = torch.tensor([0.01, 0.5, 0.99])
pred_seq = torch.quantile(pred[idx0][idx1, :, :, -1], q=q, dim=0)

plt.figure(figsize=(10, 3))
plt.plot(range(input_len + output_len), torch.concat([input_seq, target_seq]))
plt.plot(range(input_len, input_len + output_len), pred_seq[1])
plt.fill_between(range(input_len, input_len + output_len), pred_seq[0], pred_seq[2], alpha=0.5, color='tab:orange')
plt.show()