# Create a dataset

In [1]:
import sys
import os
import torch
from torch.utils.data import IterableDataset

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "..")))

from data.preprocessing.transform.dataset_builder import Builder
from data.preprocessing.transform.probabilistic_mixing_dataset import ProbabilisticMixingDataset
from data.preprocessing.downloader.gift_eval import load_gifteval_dataset_wrapper

In [22]:
class PostProcessingDataset(IterableDataset):
    def __init__(self, gift_eval_datasets, window_size=32, prediction_depth=1, seed=42, batch_size=64):
        super().__init__()

        dataset_dict = {
            str(i): (
                Builder(ds)
                .sliding_window(window_size + prediction_depth, step=window_size + prediction_depth)
                .map(lambda t: (t[:window_size], t[window_size:]))
                .build()
            )
            for i, ds in enumerate(gift_eval_datasets)
        }

        mixed_dataset = ProbabilisticMixingDataset(dataset_dict, seed=seed)

        def collate_list_of_tuples(data):
            features, targets = zip(*data)
            features = torch.stack(features)
            targets = torch.stack(targets)
            return features, targets

        self.ds = (
            Builder(mixed_dataset)
                .batch(batch_size)
                .map(collate_list_of_tuples)
                .build()
        )
        
    def __iter__(self):
        for item in iter(self.ds):
            yield item

In [23]:
ds = PostProcessingDataset([
    load_gifteval_dataset_wrapper(['bdg-2_fox']),
    load_gifteval_dataset_wrapper(['bdg-2_bear'])
])

# Define a model

In [24]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pytorch_lightning as pl
from torch import nn

class LinearRegressionModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(32, 1)

    def forward(self, x):
        # x: [batch_size, 32, 1] => squeeze last dim to [batch_size, 32]
        x = x.squeeze(-1)
        return self.linear(x)  # Output shape: [batch_size, 1]

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.float()
        y = y.float()
        y = y.squeeze(-1)  # shape: [batch_size, 1]
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Train a model

In [25]:
dataloader = DataLoader(ds, shuffle=False)

model = LinearRegressionModel()
trainer = pl.Trainer(max_epochs=1, accelerator="auto")
trainer.fit(model, dataloader)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type   | Params | Mode 
------------------------------------------
0 | linear | Linear | 33     | train
------------------------------------------
33        Trainable params
0         Non-trainable params
33        Total params
0.000     Total estimated model params size (MB)
1         Modules in train mode
0         Modules in eval mode


Training: |                                                                                                   …

`Trainer.fit` stopped: `max_epochs=1` reached.


# Test that a model predicts something

In [26]:
test_loader = DataLoader(ds, shuffle=False)

model.eval()

for x, y in test_loader:
    with torch.no_grad():
        x = x.float()
        y_pred = model(x)
    print("Input:", x.shape)
    print("Target:", y.shape, y)
    print("Prediction:", y_pred.shape, y_pred)
    break  # Only test one batch

Input: torch.Size([1, 64, 32, 1])
Target: torch.Size([1, 64, 1, 1]) tensor([[[[15.5000]],

         [[54.9000]],

         [[55.3700]],

         [[55.3300]],

         [[57.1800]],

         [[15.7000]],

         [[61.9000]],

         [[15.4000]],

         [[16.0000]],

         [[16.3000]],

         [[59.4900]],

         [[68.3600]],

         [[60.1900]],

         [[15.8000]],

         [[66.8500]],

         [[66.0400]],

         [[16.1000]],

         [[16.1000]],

         [[15.7000]],

         [[16.3000]],

         [[71.8400]],

         [[16.1000]],

         [[72.4600]],

         [[64.4500]],

         [[63.7700]],

         [[10.3000]],

         [[ 9.7000]],

         [[65.5900]],

         [[63.7700]],

         [[61.2900]],

         [[ 9.7000]],

         [[ 9.7000]],

         [[ 9.7000]],

         [[ 9.7000]],

         [[61.5800]],

         [[ 9.8000]],

         [[ 9.8000]],

         [[60.9100]],

         [[61.1100]],

         [[ 9.8000]],

         [[ 