In [1]:
!git clone https://github.com/AlesSrsen/spacetimeformer.git

Cloning into 'spacetimeformer'...
remote: Enumerating objects: 670, done.[K
remote: Counting objects: 100% (301/301), done.[K
remote: Compressing objects: 100% (117/117), done.[K
remote: Total 670 (delta 219), reused 202 (delta 184), pack-reused 369 (from 1)[K
Receiving objects: 100% (670/670), 16.35 MiB | 17.42 MiB/s, done.
Resolving deltas: 100% (358/358), done.


In [2]:
!cd /content/spacetimeformer
!ls /content/spacetimeformer

crypto_preparation.ipynb	README.md	  setup.py
crypto_preparation_srsen.ipynb	readme_media	  spacetimeformer
LICENSE				requirements.txt  spacetimeformer_google_colab.ipynb


In [3]:
# The runtime should be restarted after this step
# Install order matters: the repository requirements go in first, then the
# CUDA 12.1 PyTorch build is pinned, then pytorch-lightning 1.9.0 and the
# remaining packages are installed on top. The later commands appear to replace
# versions requested by requirements.txt (the install log shows it asking for
# pytorch-lightning==1.6).
%pip install -r /content/spacetimeformer/requirements.txt
%pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
%pip install pytorch-lightning==1.9.0 netCDF4 omegaconf performer_pytorch torchmetrics==0.9.1

Collecting cmdstanpy==0.9.68 (from -r /content/spacetimeformer/requirements.txt (line 2))
  Downloading cmdstanpy-0.9.68-py3-none-any.whl.metadata (3.7 kB)
Collecting pystan~=2.19.1.1 (from -r /content/spacetimeformer/requirements.txt (line 3))
  Downloading pystan-2.19.1.1.tar.gz (16.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.2/16.2 MB[0m [31m35.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting convertdate>=2.1.2 (from -r /content/spacetimeformer/requirements.txt (line 7))
  Downloading convertdate-2.4.0-py3-none-any.whl.metadata (8.3 kB)
Collecting performer-pytorch (from -r /content/spacetimeformer/requirements.txt (line 9))
  Downloading performer_pytorch-1.1.4-py3-none-any.whl.metadata (763 bytes)
Collecting nystrom-attention (from -r /content/spacetimeformer/requirements.txt (line 11))
  Downloading nystrom_attention-0.0.12-py3-none-any.whl.metadata (657 bytes)
Collecting pytorch-lightning==1.6 (f

In [8]:
# Install the cloned spacetimeformer package into the kernel's environment.
%pip install /content/spacetimeformer/
# For an editable (development) install, use this instead:
# %pip install -e /content/spacetimeformer/

Processing ./spacetimeformer
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: spacetimeformer
  Building wheel for spacetimeformer (setup.py) ... [?25l[?25hdone
  Created wheel for spacetimeformer: filename=spacetimeformer-1.5.0-py3-none-any.whl size=95649 sha256=0df423c89021885ecc9510fcc7bdc37a0fb325b774e80f9a0600c9ed03d371f9
  Stored in directory: /tmp/pip-ephem-wheel-cache-h01rtibo/wheels/7c/50/8b/5ddea8d5fbf5266363fe25364df45b57c105c646f99b175bc6
Successfully built spacetimeformer
Installing collected packages: spacetimeformer
Successfully installed spacetimeformer-1.5.0


In [5]:
!cd /content

# Pull and update content from the repository

# Training code
The code below is extracted from `train.py` and adapted to run in Colab without invoking the command line. This makes it easy to adjust the code and re-run it directly in the notebook.

In [1]:
# prompt: upload file

from google.colab import files

# Opens Colab's upload widget; the result maps each file name to its raw bytes.
uploaded = files.upload()
for name, payload in uploaded.items():
    print(f'User uploaded file "{name}" with length {len(payload)} bytes')


Saving preprocessed_12_cols_eth_timeseries.csv to preprocessed_12_cols_eth_timeseries.csv
User uploaded file "preprocessed_12_cols_eth_timeseries.csv" with length 3306679 bytes


In [1]:
from argparse import ArgumentParser
import os
import uuid

import pytorch_lightning as pl
import spacetimeformer as stf

# Model names accepted by `create_adjusted_parser`; each one has a matching
# branch there and in `create_model`.
_MODELS = ["spacetimeformer", "mtgnn", "heuristic", "lstm", "lstnet", "linear", "s4"]


def create_adjusted_parser(model):
    """Build the command-line style parser for a single model family.

    The parser takes the positional `model` and `dset` arguments, the options
    registered by the data classes, the options of the selected forecaster
    (plus its extra callback for `lstm`), the time-masked-loss callback options
    and the generic run-control flags.

    Args:
        model: one of the names listed in `_MODELS`.

    Returns:
        The configured `ArgumentParser`.

    Raises:
        AssertionError: if `model` is not listed in `_MODELS`.
    """
    # Throw error now before we get confusing parser issues
    assert (
        model in _MODELS
    ), f"Unrecognized model (`{model}`). Options include: {_MODELS}"

    cli = ArgumentParser()
    for positional in ("model", "dset"):
        cli.add_argument(positional)

    # Dataset / dataloader options are shared by every model.
    for data_cls in (
        stf.data.CSVTimeSeries,
        stf.data.CSVTorchDset,
        stf.data.DataModule,
    ):
        data_cls.add_cli(cli)

    # Lazily resolved so only the selected model's module is touched.
    model_cli_owners = {
        "lstm": lambda: (
            stf.lstm_model.LSTM_Forecaster,
            stf.callbacks.TeacherForcingAnnealCallback,
        ),
        "lstnet": lambda: (stf.lstnet_model.LSTNet_Forecaster,),
        "mtgnn": lambda: (stf.mtgnn_model.MTGNN_Forecaster,),
        "heuristic": lambda: (stf.heuristic_model.Heuristic_Forecaster,),
        "spacetimeformer": lambda: (
            stf.spacetimeformer_model.Spacetimeformer_Forecaster,
        ),
        "linear": lambda: (stf.linear_model.Linear_Forecaster,),
        "s4": lambda: (stf.s4_model.S4_Forecaster,),
    }
    for cli_owner in model_cli_owners.get(model, lambda: ())():
        cli_owner.add_cli(cli)

    stf.callbacks.TimeMaskedLossCallback.add_cli(cli)

    # Generic run-control flags, registered in their original order.
    run_options = [
        ("--wandb", dict(action="store_true")),
        ("--plot", dict(action="store_true")),
        ("--plot_samples", dict(type=int, default=8)),
        ("--attn_plot", dict(action="store_true")),
        ("--debug", dict(action="store_true")),
        ("--run_name", dict(type=str, required=True)),
        ("--accumulate", dict(type=int, default=1)),
        ("--val_check_interval", dict(type=float, default=1.0)),
        ("--limit_val_batches", dict(type=float, default=1.0)),
        ("--no_earlystopping", dict(action="store_true")),
        ("--patience", dict(type=int, default=5)),
        (
            "--trials",
            dict(type=int, default=1, help="How many consecutive trials to run"),
        ),
    ]
    for flag, spec in run_options:
        cli.add_argument(flag, **spec)

    return cli


def create_model(config):
    """Instantiate the forecaster selected by `config.model`.

    The data dimensions are fixed per dataset (`config.dset`); every other
    hyperparameter is read from `config` and forwarded to the matching
    forecaster class in `spacetimeformer`.

    Args:
        config: namespace-like object (e.g. the parsed CLI arguments) exposing
            `dset`, `model` and the hyperparameters of the selected model.

    Returns:
        The constructed forecaster.

    Raises:
        AssertionError: if `config.dset` has no known dimensions.
        ValueError: if `config.model` is not a supported model name, or if the
            spacetimeformer sequence length cannot be derived from `config`.
    """
    x_dim, yc_dim, yt_dim = None, None, None
    if config.dset == "dmts_crypto":
        x_dim = 6 # Time dimension
        yc_dim = 9 # Context variables dimensions
        yt_dim = 9 # Target variables dimensions
    assert x_dim is not None, f"No x_dim defined for dataset `{config.dset}`"
    assert yc_dim is not None, f"No yc_dim defined for dataset `{config.dset}`"
    assert yt_dim is not None, f"No yt_dim defined for dataset `{config.dset}`"

    print(f'Using x_dim={x_dim} yc_dim={yc_dim} yt_dim={yt_dim}')

    if config.model == "lstm":
        forecaster = stf.lstm_model.LSTM_Forecaster(
            # encoder
            d_x=x_dim,
            d_yc=yc_dim,
            d_yt=yt_dim,
            time_emb_dim=config.time_emb_dim,
            hidden_dim=config.hidden_dim,
            n_layers=config.n_layers,
            dropout_p=config.dropout_p,
            # training
            learning_rate=config.learning_rate,
            teacher_forcing_prob=config.teacher_forcing_start,
            l2_coeff=config.l2_coeff,
            loss=config.loss,
            linear_window=config.linear_window,
            use_revin=config.use_revin,
            linear_shared_weights=config.linear_shared_weights,
            use_seasonal_decomp=config.use_seasonal_decomp,
        )

    elif config.model == "heuristic":
        forecaster = stf.heuristic_model.Heuristic_Forecaster(
            d_x=x_dim,
            d_yc=yc_dim,
            d_yt=yt_dim,
            context_points=config.context_points,
            target_points=config.target_points,
            loss=config.loss,
            method=config.method,
        )
    elif config.model == "mtgnn":
        forecaster = stf.mtgnn_model.MTGNN_Forecaster(
            d_x=x_dim,
            d_yc=yc_dim,
            d_yt=yt_dim,
            context_points=config.context_points,
            target_points=config.target_points,
            gcn_depth=config.gcn_depth,
            dropout_p=config.dropout_p,
            node_dim=config.node_dim,
            dilation_exponential=config.dilation_exponential,
            conv_channels=config.conv_channels,
            subgraph_size=config.subgraph_size,
            skip_channels=config.skip_channels,
            end_channels=config.end_channels,
            residual_channels=config.residual_channels,
            layers=config.layers,
            propalpha=config.propalpha,
            tanhalpha=config.tanhalpha,
            learning_rate=config.learning_rate,
            kernel_size=config.kernel_size,
            l2_coeff=config.l2_coeff,
            time_emb_dim=config.time_emb_dim,
            loss=config.loss,
            linear_window=config.linear_window,
            linear_shared_weights=config.linear_shared_weights,
            use_seasonal_decomp=config.use_seasonal_decomp,
            use_revin=config.use_revin,
        )
    elif config.model == "lstnet":
        forecaster = stf.lstnet_model.LSTNet_Forecaster(
            d_x=x_dim,
            d_yc=yc_dim,
            d_yt=yt_dim,
            context_points=config.context_points,
            hidRNN=config.hidRNN,
            hidCNN=config.hidCNN,
            hidSkip=config.hidSkip,
            CNN_kernel=config.CNN_kernel,
            skip=config.skip,
            dropout_p=config.dropout_p,
            output_fun=config.output_fun,
            learning_rate=config.learning_rate,
            l2_coeff=config.l2_coeff,
            loss=config.loss,
            linear_window=config.linear_window,
            use_revin=config.use_revin,
        )
    elif config.model == "spacetimeformer":
        # The sequence length is context + target when both are configured,
        # otherwise it falls back to an explicit `max_len`.
        if hasattr(config, "context_points") and hasattr(config, "target_points"):
            max_seq_len = config.context_points + config.target_points
        elif hasattr(config, "max_len"):
            max_seq_len = config.max_len
        else:
            raise ValueError("Undefined max_seq_len")
        forecaster = stf.spacetimeformer_model.Spacetimeformer_Forecaster(
            d_x=x_dim,
            d_yc=yc_dim,
            d_yt=yt_dim,
            max_seq_len=max_seq_len,
            start_token_len=config.start_token_len,
            attn_factor=config.attn_factor,
            d_model=config.d_model,
            d_queries_keys=config.d_qk,
            d_values=config.d_v,
            n_heads=config.n_heads,
            e_layers=config.enc_layers,
            d_layers=config.dec_layers,
            d_ff=config.d_ff,
            dropout_emb=config.dropout_emb,
            dropout_attn_out=config.dropout_attn_out,
            dropout_attn_matrix=config.dropout_attn_matrix,
            dropout_qkv=config.dropout_qkv,
            dropout_ff=config.dropout_ff,
            pos_emb_type=config.pos_emb_type,
            use_final_norm=not config.no_final_norm,
            global_self_attn=config.global_self_attn,
            local_self_attn=config.local_self_attn,
            global_cross_attn=config.global_cross_attn,
            local_cross_attn=config.local_cross_attn,
            performer_kernel=config.performer_kernel,
            performer_redraw_interval=config.performer_redraw_interval,
            attn_time_windows=config.attn_time_windows,
            use_shifted_time_windows=config.use_shifted_time_windows,
            norm=config.norm,
            activation=config.activation,
            init_lr=config.init_lr,
            base_lr=config.base_lr,
            warmup_steps=config.warmup_steps,
            decay_factor=config.decay_factor,
            initial_downsample_convs=config.initial_downsample_convs,
            intermediate_downsample_convs=config.intermediate_downsample_convs,
            embed_method=config.embed_method,
            l2_coeff=config.l2_coeff,
            loss=config.loss,
            class_loss_imp=config.class_loss_imp,
            recon_loss_imp=config.recon_loss_imp,
            time_emb_dim=config.time_emb_dim,
            null_value=config.null_value,
            pad_value=config.pad_value,
            linear_window=config.linear_window,
            use_revin=config.use_revin,
            linear_shared_weights=config.linear_shared_weights,
            use_seasonal_decomp=config.use_seasonal_decomp,
            use_val=not config.no_val,
            use_time=not config.no_time,
            use_space=not config.no_space,
            use_given=not config.no_given,
            recon_mask_skip_all=config.recon_mask_skip_all,
            recon_mask_max_seq_len=config.recon_mask_max_seq_len,
            recon_mask_drop_seq=config.recon_mask_drop_seq,
            recon_mask_drop_standard=config.recon_mask_drop_standard,
            recon_mask_drop_full=config.recon_mask_drop_full,
        )
    elif config.model == "linear":
        forecaster = stf.linear_model.Linear_Forecaster(
            d_x=x_dim,
            d_yc=yc_dim,
            d_yt=yt_dim,
            context_points=config.context_points,
            learning_rate=config.learning_rate,
            l2_coeff=config.l2_coeff,
            loss=config.loss,
            linear_window=config.linear_window,
            linear_shared_weights=config.linear_shared_weights,
            use_revin=config.use_revin,
            use_seasonal_decomp=config.use_seasonal_decomp,
        )
    elif config.model == "s4":
        forecaster = stf.s4_model.S4_Forecaster(
            context_points=config.context_points,
            target_points=config.target_points,
            d_state=config.d_state,
            d_model=config.d_model,
            d_x=x_dim,
            d_yc=yc_dim,
            d_yt=yt_dim,
            layers=config.layers,
            time_emb_dim=config.time_emb_dim,
            channels=config.channels,
            dropout_p=config.dropout_p,
            learning_rate=config.learning_rate,
            l2_coeff=config.l2_coeff,
            loss=config.loss,
            linear_window=config.linear_window,
            linear_shared_weights=config.linear_shared_weights,
            use_revin=config.use_revin,
            use_seasonal_decomp=config.use_seasonal_decomp,
        )
    else:
        # Without this branch an unknown name reached `return forecaster` with
        # the variable unbound and failed with an opaque UnboundLocalError.
        raise ValueError(f"Unrecognized model (`{config.model}`).")

    return forecaster


def create_dset(config):
    """Build the data module and scaling helpers for `config.dset`.

    Args:
        config: namespace-like object exposing `dset`, `data_path`,
            `context_points`, `target_points`, `time_resolution`,
            `batch_size`, `workers` and `overfit`.

    Returns:
        A 7-tuple `(data_module, inv_scaler, scaler, null_val, plot_var_idxs,
        plot_var_names, pad_val)`. The two scalers are the dataset's
        `reverse_scaling` / `apply_scaling` methods; the last four entries are
        always `None` in this notebook.

    Raises:
        ValueError: if `config.dset` is not supported, or if the data path was
            left at the `"auto"` placeholder.
    """
    NULL_VAL = None
    PLOT_VAR_IDXS = None
    PLOT_VAR_NAMES = None
    PAD_VAL = None

    data_path = config.data_path
    time_features = ["year", "month", "day", "weekday", "hour", "minute"]

    if config.dset == "dmts_crypto":  # DMTS Modification
        if data_path == "auto":
            raise ValueError("Please specify a datapath.")
        target_cols = ["open","close","high","low","volume","volume_24h","market_cap","circulating_supply","volatility"]
        time_col_name = "time"
    else:
        # Previously an unknown dataset left `target_cols` unbound and failed
        # later with an UnboundLocalError; reject it explicitly instead.
        raise ValueError(
            f"Unsupported dataset (`{config.dset}`). Options include: ['dmts_crypto']"
        )

    print(config)

    dset = stf.data.CSVTimeSeries(
        data_path=data_path,
        target_cols=target_cols,
        ignore_cols="all",
        time_col_name=time_col_name,
        time_features=time_features,
        val_split=0.2,
        test_split=0.2,
    )
    DATA_MODULE = stf.data.DataModule(
        datasetCls=stf.data.CSVTorchDset,
        dataset_kwargs={
            "csv_time_series": dset,
            "context_points": config.context_points,
            "target_points": config.target_points,
            "time_resolution": config.time_resolution,
        },
        batch_size=config.batch_size,
        workers=config.workers,
        overfit=config.overfit,
    )
    # Scaling is delegated to the dataset's own methods.
    INV_SCALER = dset.reverse_scaling
    SCALER = dset.apply_scaling

    return (
        DATA_MODULE,
        INV_SCALER,
        SCALER,
        NULL_VAL,
        PLOT_VAR_IDXS,
        PLOT_VAR_NAMES,
        PAD_VAL,
    )


def create_callbacks(config, save_dir):
    """Assemble the list of Lightning callbacks for one training run.

    A checkpoint callback tracking the lowest `val/loss` is always included.
    Early stopping, learning-rate monitoring, teacher-forcing annealing and
    time-masked loss are appended according to `config`.

    Side effect: stores the run-specific checkpoint directory on
    `config.model_ckpt_dir`.

    Args:
        config: parsed run configuration.
        save_dir: parent directory for this run's checkpoint folder.

    Returns:
        The list of callbacks, checkpointing first.
    """
    # A short random suffix keeps runs with the same name from colliding.
    run_folder = f"{config.run_name}_" + str(uuid.uuid1()).split("-")[0]
    ckpt_dir = os.path.join(save_dir, run_folder)
    config.model_ckpt_dir = ckpt_dir

    checkpointing = pl.callbacks.ModelCheckpoint(
        dirpath=ckpt_dir,
        monitor="val/loss",
        mode="min",
        filename=f"{config.run_name}{{epoch:02d}}",
        save_top_k=1,
        auto_insert_metric_name=True,
    )
    active = [checkpointing]

    if not config.no_earlystopping:
        stopper = pl.callbacks.early_stopping.EarlyStopping(
            monitor="val/loss",
            patience=config.patience,
        )
        active.append(stopper)

    if config.wandb:
        active.append(pl.callbacks.LearningRateMonitor())

    if config.model == "lstm":
        teacher_forcing = stf.callbacks.TeacherForcingAnnealCallback(
            start=config.teacher_forcing_start,
            end=config.teacher_forcing_end,
            steps=config.teacher_forcing_anneal_steps,
        )
        active.append(teacher_forcing)

    if config.time_mask_loss:
        time_mask = stf.callbacks.TimeMaskedLossCallback(
            start=config.time_mask_start,
            end=config.time_mask_end,
            steps=config.time_mask_anneal_steps,
        )
        active.append(time_mask)

    return active

## Info for the dataset setup
https://chatgpt.com/share/674dd819-2574-800f-b615-deb52a1fec22

```
In the context of the code provided, the parameters x_dim, yc_dim, and yt_dim represent dimensions related to the data used for training models on time series datasets. Here is the breakdown:

x_dim:

This is the number of time-derived covariates (the time features) supplied alongside the series, not the number of series being modelled.
For the exchange dataset, x_dim is set to 6: one value for each of year, month, day, weekday, hour, and minute.
yc_dim:

This represents the dimension of the context variables or features that are available in both the input (past) and the output (future) during training and forecasting.
For the exchange dataset, yc_dim is 8. This corresponds to the 8 target columns in the dataset: "Australia", "United Kingdom", "Canada", "Switzerland", "China", "Japan", "New Zealand", "Singapore".
yt_dim:

This represents the dimension of the target variables or features being forecasted during training and inference.
For the exchange dataset, yt_dim is also 8, matching the number of target columns.
Summary for the exchange Dataset:
x_dim = 6: Refers to the input feature dimensions, typically derived from time-related attributes.
yc_dim = 8: Refers to the number of context variables, corresponding to the 8 exchange rates.
yt_dim = 8: Refers to the number of target variables being forecasted, also corresponding to the 8 exchange rates.
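These parameters are critical for configuring models to understand the input-output relationships in time series forecasting tasks.
For the `dmts_crypto` dataset used in this notebook the same logic gives x_dim = 6 (the six time features) and yc_dim = yt_dim = 9 (the nine target columns: open, close, high, low, volume, volume_24h, market_cap, circulating_supply, volatility).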
```

In [3]:
# Create the directory that STF_LOG_DIR points to (wandb logs and checkpoints);
# `-p` makes this a no-op when it already exists.
!mkdir -p /content/wandb

In [4]:
# Interactive step: prompts for a W&B API key and stores it in the netrc file
# (see the output below). Enter the key at the prompt; do not paste it into a cell.
!wandb login --relogin

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [2]:
# Training setup
STF_WANDB_ACCT = "dmts"
STF_WANDB_PROJ = "DMTS_UPM"
STF_LOG_DIR = "/content/wandb/"

RUN_NAME = "TSDM_-_Full_run_"

DATASET_PATH = "/content/preprocessed_12_cols_eth_timeseries.csv"


def _build_args(model, run_suffix, extra=(), use_wandb=True):
    """Return the argv-style list for one experiment.

    Every experiment shares the dataset, batch size, context length and GPU
    selection; `extra` carries the model-specific flags, inserted between the
    batch size and the run name.
    """
    argv = [model, "dmts_crypto", "--data_path", DATASET_PATH]
    if use_wandb:
        argv.append("--wandb")
    argv.extend(["--batch_size", "24"])
    argv.extend(extra)
    argv.extend(
        [
            "--run_name",
            RUN_NAME + run_suffix,
            "--context_points",
            "10",
            "--gpus",
            "0",
        ]
    )
    return argv


# Full spatio-temporal attention, with attention-matrix plots.
ARGS_SPACETIMEFORMER = _build_args(
    "spacetimeformer",
    "SPACETIMEFORMER",
    extra=[
        "--attn_plot",
        "--embed_method", "spatio-temporal",
        "--local_self_attn", "full",
        "--local_cross_attn", "full",
        "--global_self_attn", "full",
        "--global_cross_attn", "full",
    ],
)

# LSTM baseline; wandb logging is switched off for this configuration.
ARGS_LSTM = _build_args("lstm", "LSTM", use_wandb=False)

# Temporal-only embedding: local attention disabled, global attention kept.
ARGS_SPACETIMEFORMER_TEMPORAL = _build_args(
    "spacetimeformer",
    "SPACETIMEFORMER_TEMPORAL",
    extra=[
        "--embed_method", "temporal",
        "--local_self_attn", "none",
        "--local_cross_attn", "none",
        "--global_self_attn", "full",
        "--global_cross_attn", "full",
    ],
)


In [3]:
# Select the experiment to run; the first argv entry is the model name.
ARGS = ARGS_SPACETIMEFORMER
MODEL = ARGS[0]

# Parse the argv-style list exactly as train.py would parse the command line.
_parser = create_adjusted_parser(MODEL)
ARG_CONFIG = _parser.parse_args(ARGS)

In [4]:
# End-to-end run: set up logging, build data/model/callbacks, train, then test
# the best checkpoint.

# --- Logging directory ---
log_dir = STF_LOG_DIR
if log_dir is None:
    log_dir = "./data/STF_LOG_DIR"
    print(
        "Using default wandb log dir path of ./data/STF_LOG_DIR. This can be adjusted with the `STF_LOG_DIR` variable in the training setup cell."
    )
# exist_ok avoids the check-then-create race and keeps the cell re-runnable.
os.makedirs(log_dir, exist_ok=True)

# --- Weights & Biases (optional) ---
# `logger` stays None unless wandb logging is requested, so the Trainer call
# below never depends on a conditionally defined name.
logger = None
if ARG_CONFIG.wandb:
    import wandb  # optional dependency, only needed for --wandb runs

    project = STF_WANDB_PROJ
    entity = STF_WANDB_ACCT
    assert project is not None and entity is not None, (
        "Please set `STF_WANDB_ACCT` and `STF_WANDB_PROJ` in the training setup "
        "cell to your wandb user/organization name and project title, respectively."
    )
    experiment = wandb.init(
        project=project,
        entity=entity,
        config=ARG_CONFIG,
        dir=log_dir,
        reinit=True,
    )
    config = wandb.config
    wandb.run.name = ARG_CONFIG.run_name
    wandb.run.save()
    logger = pl.loggers.WandbLogger(
        experiment=experiment,
        save_dir=log_dir,
    )

# --- Dataset ---
(
    data_module,
    inv_scaler,
    scaler,
    null_val,
    plot_var_idxs,
    plot_var_names,
    pad_val,
) = create_dset(ARG_CONFIG)

# --- Model ---
# create_model reads null_value / pad_value from the config (spacetimeformer).
ARG_CONFIG.null_value = null_val
ARG_CONFIG.pad_value = pad_val
forecaster = create_model(ARG_CONFIG)
forecaster.set_inv_scaler(inv_scaler)
forecaster.set_scaler(scaler)
forecaster.set_null_value(null_val)

# --- Callbacks ---
callbacks = create_callbacks(ARG_CONFIG, save_dir=log_dir)
# One fixed test batch, reused by the plotting callbacks and for later inspection.
test_samples = next(iter(data_module.test_dataloader()))

if ARG_CONFIG.wandb and ARG_CONFIG.plot:
    callbacks.append(
        stf.plot.PredictionPlotterCallback(
            test_samples,
            var_idxs=plot_var_idxs,
            var_names=plot_var_names,
            pad_val=pad_val,
            total_samples=min(ARG_CONFIG.plot_samples, ARG_CONFIG.batch_size),
        )
    )

if ARG_CONFIG.wandb and ARG_CONFIG.model == "spacetimeformer" and ARG_CONFIG.attn_plot:
    callbacks.append(
        stf.plot.AttentionMatrixCallback(
            test_samples,
            layer=0,
            total_samples=min(16, ARG_CONFIG.batch_size),
        )
    )

if ARG_CONFIG.wandb:
    # Re-sync: ARG_CONFIG gained fields (null/pad values, checkpoint dir) after wandb.init.
    config.update(ARG_CONFIG)
    logger.log_hyperparams(config)

# Values <= 1.0 are a fraction of an epoch; larger values mean "every N epochs".
if ARG_CONFIG.val_check_interval <= 1.0:
    val_control = {"val_check_interval": ARG_CONFIG.val_check_interval}
else:
    val_control = {"check_val_every_n_epoch": int(ARG_CONFIG.val_check_interval)}

# --- Trainer ---
# `devices=` replaces the `gpus=` argument, which is deprecated in the
# pytorch-lightning 1.x series and removed in 2.0; with accelerator="cuda" it
# selects the same GPU indices.
trainer = pl.Trainer(
    accelerator="cuda",
    devices=ARG_CONFIG.gpus,
    callbacks=callbacks,
    logger=logger,
    gradient_clip_val=ARG_CONFIG.grad_clip_norm,
    gradient_clip_algorithm="norm",
    overfit_batches=20 if ARG_CONFIG.debug else 0,
    accumulate_grad_batches=ARG_CONFIG.accumulate,
    sync_batchnorm=True,
    limit_val_batches=ARG_CONFIG.limit_val_batches,
    **val_control,
)

# Train
trainer.fit(forecaster, datamodule=data_module)

# Test (reloads the best checkpoint saved by ModelCheckpoint)
trainer.test(datamodule=data_module, ckpt_path="best")

if ARG_CONFIG.wandb:
    experiment.finish()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msrsen[0m ([33mdmts[0m). Use [1m`wandb login --relogin`[0m to force relogin


  main_df["Month"] = dates.apply(
  main_df["Day"] = dates.apply(lambda row: 2.0 * ((row.day - 1) / 30.0) - 1.0, 1)
  main_df["Weekday"] = dates.apply(


Namespace(model='spacetimeformer', dset='dmts_crypto', data_path='/content/preprocessed_12_cols_eth_timeseries.csv', context_points=10, target_points=32, time_resolution=1, batch_size=24, workers=6, overfit=False, gpus=[0], l2_coeff=1e-06, learning_rate=0.0001, grad_clip_norm=0, linear_window=0, use_revin=False, loss='mse', linear_shared_weights=False, use_seasonal_decomp=False, start_token_len=0, d_model=200, d_qk=200, d_v=200, n_heads=4, enc_layers=3, dec_layers=3, d_ff=800, attn_factor=5, dropout_emb=0.2, dropout_attn_matrix=0.0, dropout_qkv=0.0, dropout_ff=0.3, dropout_attn_out=0.0, global_self_attn='full', global_cross_attn='full', local_self_attn='full', local_cross_attn='full', activation='gelu', norm='batch', init_lr=1e-10, base_lr=0.0005, warmup_steps=0, decay_factor=0.25, initial_downsample_convs=0, class_loss_imp=0.1, recon_loss_imp=0.0, intermediate_downsample_convs=0, time_emb_dim=6, performer_kernel='relu', performer_redraw_interval=100, embed_method='spatio-temporal', at

  main_df["Hour"] = dates.apply(lambda row: 2.0 * ((row.hour) / 23.0) - 1.0, 1)
  main_df["Minute"] = dates.apply(


Using x_dim=6 yc_dim=9 yt_dim=9
Forecaster
	L2: 1e-06
	Linear Window: 0
	Linear Shared Weights: False
	RevIN: False
	Decomposition: False
GlobalSelfAttn: AttentionLayer(
  (inner_attention): FullAttention(
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (query_projection): Linear(in_features=200, out_features=800, bias=True)
  (key_projection): Linear(in_features=200, out_features=800, bias=True)
  (value_projection): Linear(in_features=200, out_features=800, bias=True)
  (out_projection): Linear(in_features=800, out_features=200, bias=True)
  (dropout_qkv): Dropout(p=0.0, inplace=False)
)
GlobalCrossAttn: AttentionLayer(
  (inner_attention): FullAttention(
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (query_projection): Linear(in_features=200, out_features=800, bias=True)
  (key_projection): Linear(in_features=200, out_features=800, bias=True)
  (value_projection): Linear(in_features=200, out_features=800, bias=True)
  (out_projection): Linear(in_features=800, out_features=20

  rank_zero_deprecation(
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
INFO:pytorch_lightning.utilities.rank_zero:`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
  rank_zero_warn(
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name            | Type            | Params
----------------------------------------------------
0 | spacetimeformer | Spacetimeformer | 13.5 M
----------------------------------------------------
13.5 M    Trainable par

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/wandb/TSDM_-_Full_run_SPACETIMEFORMER_74ce1318/TSDM_-_Full_run_SPACETIMEFORMERepoch=00.ckpt
  return torch.load(f, map_location=map_location)
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from checkpoint at /content/wandb/TSDM_-_Full_run_SPACETIMEFORMER_74ce1318/TSDM_-_Full_run_SPACETIMEFORMERepoch=00.ckpt


Testing: 0it [00:00, ?it/s]

0,1
epoch,▁▁▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▃▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇█
global_step,▁▁▁▂▂▂▃▃▃▅▅▅▆▆▆▇▇▇███
lr-AdamW,███▁▁▁
test/acc,▁
test/class_loss,▁
test/forecast_loss,▁
test/loss,▁
test/mae,▁
test/mape,▁
test/mse,▁

0,1
epoch,6.0
global_step,3210.0
lr-AdamW,0.00013
test/acc,1.0
test/class_loss,0.01085
test/forecast_loss,0.25639
test/loss,0.25747
test/mae,300434678159.59894
test/mape,484.22862
test/mse,9.471357197300885e+23


In [10]:
# Run inference on the test batch that was captured before training.
forecaster.to("cuda")  # move the trained model to the GPU before predicting
# `test_samples` unpacks into four items; the fourth is discarded here
# (presumably the ground-truth targets — TODO confirm against CSVTorchDset).
xc, yc, xt, _ = test_samples
yt_pred = forecaster.predict(xc, yc, xt)
yt_pred  # last expression: displayed as the cell output

tensor([[[2.1500e+03, 1.8671e+03, 1.5819e+03,  ..., 2.8938e+11,
          1.1375e+08, 7.3409e-02],
         [1.9893e+03, 1.7782e+03, 1.3432e+03,  ..., 2.0533e+11,
          1.1390e+08, 7.1906e-02],
         [2.4736e+03, 1.7356e+03, 1.8310e+03,  ..., 3.0912e+11,
          1.1380e+08, 6.7961e-02],
         ...,
         [2.8288e+03, 1.7910e+03, 1.4048e+03,  ..., 2.9388e+11,
          1.1404e+08, 6.9880e-02],
         [1.7430e+03, 1.7546e+03, 1.3429e+03,  ..., 1.0815e+11,
          1.1412e+08, 6.0804e-02],
         [1.6250e+03, 1.6261e+03, 1.2926e+03,  ..., 1.0608e+11,
          1.1398e+08, 5.9591e-02]],

        [[2.5045e+03, 1.8245e+03, 1.5920e+03,  ..., 3.3032e+11,
          1.1403e+08, 7.2349e-02],
         [1.5858e+03, 1.8009e+03, 1.4385e+03,  ..., 2.7699e+11,
          1.1402e+08, 6.5591e-02],
         [2.8080e+03, 1.8610e+03, 2.1380e+03,  ..., 3.1118e+11,
          1.1402e+08, 7.0684e-02],
         ...,
         [1.8174e+03, 1.8386e+03, 2.1127e+03,  ..., 2.8818e+11,
          1.139

In [12]:
# Display the third element of the test batch (passed to `predict` as `xt`);
# presumably the scaled time features of the target window — TODO confirm.
xt

tensor([[[-0.2000,  1.0000, -1.0000,  0.2727,  0.6667,  0.6667],
         [-0.1333, -1.0000, -1.0000,  0.2727,  1.0000,  0.6667],
         [-0.1333, -0.9130, -1.0000,  0.2727,  1.0000,  0.6667],
         ...,
         [-0.0667, -0.6522, -1.0000,  0.2727, -1.0000,  0.6667],
         [-0.0667, -0.5652, -1.0000,  0.2727, -1.0000,  0.6667],
         [-0.0667, -0.4783, -1.0000,  0.2727, -1.0000,  0.6667]],

        [[-0.1333, -1.0000, -1.0000,  0.2727,  1.0000,  0.6667],
         [-0.1333, -0.9130, -1.0000,  0.2727,  1.0000,  0.6667],
         [-0.1333, -0.8261, -1.0000,  0.2727,  1.0000,  0.6667],
         ...,
         [-0.0667, -0.5652, -1.0000,  0.2727, -1.0000,  0.6667],
         [-0.0667, -0.4783, -1.0000,  0.2727, -1.0000,  0.6667],
         [-0.0667, -0.3913, -1.0000,  0.2727, -1.0000,  0.6667]],

        [[-0.1333, -0.9130, -1.0000,  0.2727,  1.0000,  0.6667],
         [-0.1333, -0.8261, -1.0000,  0.2727,  1.0000,  0.6667],
         [-0.1333, -0.7391, -1.0000,  0.2727,  1.0000,  0.