# Create a dataset

In [27]:
import sys
import os
import torch
from torch.utils.data import IterableDataset

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "..")))

from data.preprocessing.transform.dataset_builder import Builder
from data.preprocessing.transform.probabilistic_mixing_dataset import ProbabilisticMixingDataset
from data.preprocessing.downloader.gift_eval import load_gifteval_dataset_wrapper

In [28]:
class PostProcessingDataset(IterableDataset):
    def __init__(self, gift_eval_datasets, window_size=32, prediction_depth=1, seed=42, batch_size=64):
        super().__init__()

        dataset_dict = {
            str(i): (
                Builder(ds)
                .sliding_window(window_size + prediction_depth, step=window_size + prediction_depth)
                .map(lambda t: (t[:window_size], t[window_size:]))
                .build()
            )
            for i, ds in enumerate(gift_eval_datasets)
        }

        mixed_dataset = ProbabilisticMixingDataset(dataset_dict, seed=seed)

        def collate_list_of_tuples(data):
            features, targets = zip(*data)
            features = torch.stack(features)
            targets = torch.stack(targets)
            return features, targets

        self.ds = (
            Builder(mixed_dataset)
                .batch(batch_size)
                .map(collate_list_of_tuples)
                .build()
        )
        
    def __iter__(self):
        for item in iter(self.ds):
            yield item

In [29]:
ds = PostProcessingDataset([
    load_gifteval_dataset_wrapper(['bdg-2_fox']),
    load_gifteval_dataset_wrapper(['bdg-2_bear'])
])

# Define a model

In [30]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pytorch_lightning as pl
from torch import nn

class LinearRegressionModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(32, 1)

    def forward(self, x):
        # x: [batch_size, 32, 1] => squeeze last dim to [batch_size, 32]
        x = x.squeeze(-1)
        return self.linear(x)  # Output shape: [batch_size, 1]

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.float()
        y = y.float()
        y = y.squeeze(-1)  # shape: [batch_size, 1]
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Train a model

In [35]:
dataloader = DataLoader(ds, shuffle=False, batch_size=None)

model = LinearRegressionModel()
trainer = pl.Trainer(max_epochs=1, accelerator="auto")
trainer.fit(model, dataloader)

# Test that a model predicts something

In [33]:
x, y = next(iter(test_loader))

In [37]:
test_loader = DataLoader(ds, shuffle=False, batch_size=None)

model.eval()

for x, y in test_loader:
    with torch.no_grad():
        x = x.float()
        y_pred = model(x)
    print("Input:", x.shape)
    print("Target:", y.shape, y)
    print("Prediction:", y_pred.shape, y_pred)
    break  # Only test one batch

Input: torch.Size([64, 32, 1])
Target: torch.Size([64, 1, 1]) tensor([[[54.9000]],

        [[55.3700]],

        [[55.3300]],

        [[15.5000]],

        [[57.1800]],

        [[61.9000]],

        [[15.7000]],

        [[15.4000]],

        [[16.0000]],

        [[59.4900]],

        [[68.3600]],

        [[16.3000]],

        [[60.1900]],

        [[15.8000]],

        [[66.8500]],

        [[66.0400]],

        [[71.8400]],

        [[16.1000]],

        [[16.1000]],

        [[15.7000]],

        [[72.4600]],

        [[64.4500]],

        [[16.3000]],

        [[16.1000]],

        [[10.3000]],

        [[ 9.7000]],

        [[63.7700]],

        [[ 9.7000]],

        [[65.5900]],

        [[ 9.7000]],

        [[63.7700]],

        [[ 9.7000]],

        [[61.2900]],

        [[61.5800]],

        [[ 9.7000]],

        [[60.9100]],

        [[ 9.8000]],

        [[ 9.8000]],

        [[ 9.8000]],

        [[ 9.7000]],

        [[61.1100]],

        [[ 9.8000]],

        [[59.7