# Settings

In [1]:
from argparse import ArgumentParser

import random
import sys
import warnings
import os
import uuid

import pytorch_lightning as pl
import torch

import spacetimeformer as stf

  from .autonotebook import tqdm as notebook_tqdm


# Baseline codes

python train.py spacetimeformer solar_energy --context_points 168 --target_points 24 --d_model 100 --d_ff 400 --enc_layers 5 --dec_layers 5 --l2_coeff 1e-3 --dropout_ff .2 --dropout_emb .1 --d_qk 20 --d_v 20 --n_heads 6 --run_name spatiotemporal_al_solar --batch_size 32 --class_loss_imp 0 --initial_downsample_convs 1 --decay_factor .8 --warmup_steps 1000

Setting `embed_method = temporal` gives a model that is conceptually similar to [Informer](https://arxiv.org/abs/2012.07436). Spacetimeformer has many configurable options, and we try to provide a thorough explanation of each through the command-line `-h` help.

In [2]:
# Baseline configuration: a single model on the AL Solar benchmark.
model = ["spacetimeformer"]
DSETS = ["solar_energy"]

# CSV with the solar series; its target columns are named "0" .. "136".
data_path = "./data/solar_AL_converted.csv"
target_cols = list(map(str, range(137)))

- `solar_energy`: Is the codebase's name for the time series benchmark more commonly called "AL Solar."


## Utils

In [2]:
def create_parser():
    """Build the argument parser that configures a training run.

    ``--model`` and ``--dset`` are ordinary options with defaults.  The data
    classes, the forecaster and the time-masked-loss callback each register
    their own options through their ``add_cli`` hooks, and the remaining
    run-control flags are added here.

    Help is left entirely to argparse's built-in ``-h/--help``.  The previous
    manual inspection of ``sys.argv[3]`` was removed: it dated from a version
    that read model/dataset positionally, and it could terminate the process
    based on arguments that are unrelated to this parser.

    Returns:
        ArgumentParser: the configured (not yet parsed) parser.  ``--run_name``
        is the only required option added in this function.
    """
    parser = ArgumentParser()
    parser.add_argument("--model", type=str, default="spacetimeformer")
    parser.add_argument("--dset", type=str, default="asos")

    # Data-related options (per the original notes: data path; context/target
    # points and time resolution; batch size, workers, overfit).
    stf.data.CSVTimeSeries.add_cli(parser)
    stf.data.CSVTorchDset.add_cli(parser)
    stf.data.DataModule.add_cli(parser)

    # Model hyper-parameters of the spacetimeformer forecaster.
    stf.spacetimeformer_model.Spacetimeformer_Forecaster.add_cli(parser)

    # Options of the time-masked-loss callback.
    stf.callbacks.TimeMaskedLossCallback.add_cli(parser)

    # Logging / plotting switches.
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--plot", action="store_true")
    parser.add_argument("--plot_samples", type=int, default=8)
    parser.add_argument("--attn_plot", action="store_true")

    # Run control.
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--run_name", type=str, required=True)
    parser.add_argument("--accumulate", type=int, default=1)
    parser.add_argument("--val_check_interval", type=float, default=1.0)
    parser.add_argument("--limit_val_batches", type=float, default=1.0)
    parser.add_argument("--no_earlystopping", action="store_true")
    parser.add_argument("--patience", type=int, default=5)
    parser.add_argument(
        "--trials", type=int, default=1, help="How many consecutive trials to run"
    )

    return parser

In [3]:
def create_model(config):
    """Instantiate the spacetimeformer forecaster described by ``config``.

    The input/output dimensions depend on the dataset, so they are looked up
    from ``config.dset`` rather than being fixed to the solar values.  This
    keeps the model consistent with ``create_dset``, which supports both
    ``asos`` (6 target columns) and ``solar_energy`` (137 target columns);
    both use the same six time features.  A config without a ``dset``
    attribute falls back to the solar dimensions.

    Args:
        config: namespace holding the forecaster hyper-parameters, plus either
            ``context_points`` and ``target_points`` or ``max_len``.

    Returns:
        The constructed ``Spacetimeformer_Forecaster``.

    Raises:
        ValueError: if ``config.dset`` is not a supported dataset, or if the
            maximum sequence length cannot be derived from ``config``.
    """
    # dataset name -> (d_x, d_yc, d_yt)
    dims_by_dset = {
        "asos": (6, 6, 6),
        "solar_energy": (6, 137, 137),
    }
    dset_name = getattr(config, "dset", "solar_energy")
    if dset_name not in dims_by_dset:
        raise ValueError(
            f"Unsupported dset {dset_name!r}; expected one of {sorted(dims_by_dset)}"
        )
    x_dim, yc_dim, yt_dim = dims_by_dset[dset_name]

    # The model must be able to hold context and target points together.
    if hasattr(config, "context_points") and hasattr(config, "target_points"):
        max_seq_len = config.context_points + config.target_points
    elif hasattr(config, "max_len"):
        max_seq_len = config.max_len
    else:
        raise ValueError("Undefined max_seq_len")

    forecaster = stf.spacetimeformer_model.Spacetimeformer_Forecaster(
        d_x=x_dim,
        d_yc=yc_dim,
        d_yt=yt_dim,
        max_seq_len=max_seq_len,
        start_token_len=config.start_token_len,
        attn_factor=config.attn_factor,
        d_model=config.d_model,
        d_queries_keys=config.d_qk,
        d_values=config.d_v,
        n_heads=config.n_heads,
        e_layers=config.enc_layers,
        d_layers=config.dec_layers,
        d_ff=config.d_ff,
        dropout_emb=config.dropout_emb,
        dropout_attn_out=config.dropout_attn_out,
        dropout_attn_matrix=config.dropout_attn_matrix,
        dropout_qkv=config.dropout_qkv,
        dropout_ff=config.dropout_ff,
        pos_emb_type=config.pos_emb_type,
        # The CLI exposes negative flags (no_*); the model takes positive ones.
        use_final_norm=not config.no_final_norm,
        global_self_attn=config.global_self_attn,
        local_self_attn=config.local_self_attn,
        global_cross_attn=config.global_cross_attn,
        local_cross_attn=config.local_cross_attn,
        performer_kernel=config.performer_kernel,
        performer_redraw_interval=config.performer_redraw_interval,
        attn_time_windows=config.attn_time_windows,
        use_shifted_time_windows=config.use_shifted_time_windows,
        norm=config.norm,
        activation=config.activation,
        init_lr=config.init_lr,
        base_lr=config.base_lr,
        warmup_steps=config.warmup_steps,
        decay_factor=config.decay_factor,
        initial_downsample_convs=config.initial_downsample_convs,
        intermediate_downsample_convs=config.intermediate_downsample_convs,
        embed_method=config.embed_method,
        l2_coeff=config.l2_coeff,
        loss=config.loss,
        class_loss_imp=config.class_loss_imp,
        recon_loss_imp=config.recon_loss_imp,
        time_emb_dim=config.time_emb_dim,
        null_value=config.null_value,
        pad_value=config.pad_value,
        linear_window=config.linear_window,
        use_revin=config.use_revin,
        linear_shared_weights=config.linear_shared_weights,
        use_seasonal_decomp=config.use_seasonal_decomp,
        use_val=not config.no_val,
        use_time=not config.no_time,
        use_space=not config.no_space,
        use_given=not config.no_given,
        recon_mask_skip_all=config.recon_mask_skip_all,
        recon_mask_max_seq_len=config.recon_mask_max_seq_len,
        recon_mask_drop_seq=config.recon_mask_drop_seq,
        recon_mask_drop_standard=config.recon_mask_drop_standard,
        recon_mask_drop_full=config.recon_mask_drop_full,
    )

    return forecaster

In [4]:
def create_dset(config):
    """Build the data module and scaling helpers for ``config.dset``.

    Supported datasets are ``asos`` and ``solar_energy``.  When
    ``config.data_path`` is ``"auto"`` the dataset's default CSV is used;
    any other value is passed through unchanged.

    Args:
        config: namespace with ``dset``, ``data_path``, ``context_points``,
            ``target_points``, ``time_resolution``, ``batch_size``,
            ``workers`` and ``overfit``.

    Returns:
        tuple: ``(data_module, inv_scaler, scaler, null_val, plot_var_idxs,
        plot_var_names, pad_val)``.  The two scalers are the
        ``reverse_scaling`` / ``apply_scaling`` methods of the CSV series; the
        last four entries are always ``None`` here.

    Raises:
        ValueError: if ``config.dset`` is not a supported dataset (this used
            to surface later as an unbound ``target_cols``).
    """
    time_col_name = "Datetime"
    time_features = ["year", "month", "day", "weekday", "hour", "minute"]

    # dataset name -> (CSV used when data_path == "auto", target columns)
    known_dsets = {
        "asos": (
            "./data/temperature-v1.csv",
            ["ABI", "AMA", "ACT", "ALB", "JFK", "LGA"],
        ),
        "solar_energy": (
            "./data/solar_AL_converted.csv",
            [str(i) for i in range(137)],
        ),
    }
    if config.dset not in known_dsets:
        raise ValueError(
            f"Unsupported dset {config.dset!r}; expected one of {sorted(known_dsets)}"
        )
    default_path, target_cols = known_dsets[config.dset]
    data_path = default_path if config.data_path == "auto" else config.data_path

    # Pre-processing wrapper around the CSV.
    dset = stf.data.CSVTimeSeries(
        data_path=data_path,
        target_cols=target_cols,
        ignore_cols="all",
        time_col_name=time_col_name,
        time_features=time_features,
        val_split=0.2,
        test_split=0.2,
    )

    # Loader settings come from the global parser namespace.
    data_module = stf.data.DataModule(
        datasetCls=stf.data.CSVTorchDset,
        dataset_kwargs={
            "csv_time_series": dset,
            "context_points": config.context_points,
            "target_points": config.target_points,
            "time_resolution": config.time_resolution,
        },
        batch_size=config.batch_size,
        workers=config.workers,
        overfit=config.overfit,
    )

    # Scaling is delegated to the series object; nothing else is configured
    # for these datasets.
    inv_scaler = dset.reverse_scaling
    scaler = dset.apply_scaling
    null_val = None
    plot_var_idxs = None
    plot_var_names = None
    pad_val = None

    return (
        data_module,
        inv_scaler,
        scaler,
        null_val,
        plot_var_idxs,
        plot_var_names,
        pad_val,
    )


* pytorch lightning callback
    - 마지막 epoch 체크 포인트 저장이 아니라, 매 epoch마다 저장하는 등
    - 세부적으로 저장할 때 쓰는 모듈임. : ModelCheckpoint
    - wandb : Weights & Biases
    - time_mask_loss

In [5]:
def create_callbacks(config, save_dir):
    """Assemble the Lightning callbacks for one run.

    A checkpoint callback is always created; early stopping, learning-rate
    monitoring, teacher-forcing annealing and time-masked loss are appended
    depending on ``config``.

    Args:
        config: run configuration.  ``config.model_ckpt_dir`` is set here to
            the checkpoint directory chosen for this run.
        save_dir: parent directory under which the checkpoint directory is
            placed.

    Returns:
        list: the callbacks, checkpoint callback first.
    """
    # "<run_name>_<first uuid1 field>" keeps runs with the same name apart.
    run_tag = str(uuid.uuid1()).split("-")[0]
    ckpt_dir = os.path.join(save_dir, f"{config.run_name}_" + run_tag)
    config.model_ckpt_dir = ckpt_dir

    # Keep only the single best checkpoint according to validation loss.
    stack = [
        pl.callbacks.ModelCheckpoint(
            dirpath=ckpt_dir,
            monitor="val/loss",
            mode="min",
            filename=f"{config.run_name}" + "{epoch:02d}",
            save_top_k=1,
            auto_insert_metric_name=True,
        )
    ]

    use_early_stopping = not config.no_earlystopping
    if use_early_stopping:
        stopper = pl.callbacks.early_stopping.EarlyStopping(
            monitor="val/loss",
            patience=config.patience,
        )
        stack.append(stopper)

    if config.wandb:
        stack.append(pl.callbacks.LearningRateMonitor())

    if config.model == "lstm":
        annealer = stf.callbacks.TeacherForcingAnnealCallback(
            start=config.teacher_forcing_start,
            end=config.teacher_forcing_end,
            steps=config.teacher_forcing_anneal_steps,
        )
        stack.append(annealer)

    if config.time_mask_loss:
        masker = stf.callbacks.TimeMaskedLossCallback(
            start=config.time_mask_start,
            end=config.time_mask_end,
            steps=config.time_mask_anneal_steps,
        )
        stack.append(masker)

    return stack
    

### main

**Wandb**: MLOps Tools로 모델 학습 추적을 하고, 더 나은 모델을 빨리 만들어주는 머신러닝 툴이다.
* 실험과정 Tracking, 시각화 툴임.
* Neptune AI랑 비슷한 결이라고 생각하면 될듯!

이거 없이 돌리는 코드만 짜면 될듯

In [6]:
def main(args):
    """Run one full train-and-test cycle for the configuration in ``args``.

    The function resolves the log directory, builds the data module, model and
    callbacks through ``create_dset`` / ``create_model`` / ``create_callbacks``,
    fits a ``pl.Trainer`` and finally tests the best checkpoint.  The wandb
    logging code is kept only as commented-out reference.

    Args:
        args: namespace produced by ``create_parser().parse_args(...)``.  It is
            mutated: ``null_value`` and ``pad_value`` are set here, and
            ``create_callbacks`` adds ``model_ckpt_dir``.
    """
    log_dir = os.getenv("STF_LOG_DIR")
    if log_dir is None:
        log_dir = "./data/STF_LOG_DIR"
        print(
            "Using default wandb log dir path of ./data/STF_LOG_DIR. This can be adjusted with the environment variable `STF_LOG_DIR`"
        )
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(log_dir, exist_ok=True)

    # wandb is disabled here: keep this commented out (or run with wandb=False).
    # if args.wandb:
    #     import wandb

    #     project = os.getenv("STF_WANDB_PROJ")
    #     entity = os.getenv("STF_WANDB_ACCT")
    #     assert (
    #         project is not None and entity is not None
    #     ), "Please set environment variables `STF_WANDB_ACCT` and `STF_WANDB_PROJ` with \n\
    #         your wandb user/organization name and project title, respectively."
    #     experiment = wandb.init(
    #         project=project,
    #         entity=entity,
    #         config=args,
    #         dir=log_dir,
    #         reinit=True,
    #     )
    #     config = wandb.config
    #     wandb.run.name = args.run_name
    #     wandb.run.save()
    #     logger = pl.loggers.WandbLogger(
    #         experiment=experiment,
    #         save_dir=log_dir,
    #     )

    # Dset
    (
        data_module,
        inv_scaler,
        scaler,
        null_val,
        plot_var_idxs,
        plot_var_names,
        pad_val,
    ) = create_dset(args)

    # Model: the null/pad markers chosen by the dataset are handed to the model
    # through the shared namespace before it is built.
    args.null_value = null_val
    args.pad_value = pad_val
    forecaster = create_model(args)
    forecaster.set_inv_scaler(inv_scaler)
    forecaster.set_scaler(scaler)
    forecaster.set_null_value(null_val)

    # Callbacks
    callbacks = create_callbacks(args, save_dir=log_dir)
    # Fetch a single test batch; the plotting callbacks below use it.
    test_samples = next(iter(data_module.test_dataloader()))

    # Plotting (only active together with wandb)
    if args.wandb and args.plot:
        callbacks.append(
            stf.plot.PredictionPlotterCallback(
                test_samples,
                var_idxs=plot_var_idxs,
                var_names=plot_var_names,
                pad_val=pad_val,
                total_samples=min(args.plot_samples, args.batch_size),
            )
        )

    # if args.wandb and args.dset in ["mnist", "cifar"] and args.plot:
    #     callbacks.append(
    #         stf.plot.ImageCompletionCallback(
    #             test_samples,
    #             total_samples=min(16, args.batch_size),
    #             mode="left-right" if config.dset == "mnist" else "flat",
    #         )
    #     )

    # if args.wandb and args.dset == "copy" and args.plot:
    #     callbacks.append(
    #         stf.plot.CopyTaskCallback(
    #             test_samples,
    #             total_samples=min(16, args.batch_size),
    #         )
    #     )

    if args.wandb and args.model == "spacetimeformer" and args.attn_plot:

        callbacks.append(
            stf.plot.AttentionMatrixCallback(
                test_samples,
                layer=0,
                total_samples=min(16, args.batch_size),
            )
        )

    # if args.wandb:
    #     config.update(args)
    #     logger.log_hyperparams(config)

    # Values up to 1.0 go to `val_check_interval`; larger values are truncated
    # to an int and used as `check_val_every_n_epoch`.
    if args.val_check_interval <= 1.0:
        val_control = {"val_check_interval": args.val_check_interval}
    else:
        val_control = {"check_val_every_n_epoch": int(args.val_check_interval)}

    # PyTorch Lightning trainer definition
    trainer_kwargs = dict(
        gpus=args.gpus,
        callbacks=callbacks,
        # logger=logger if args.wandb else None,
        gradient_clip_val=args.grad_clip_norm,
        gradient_clip_algorithm="norm",
        overfit_batches=20 if args.debug else 0,
        accumulate_grad_batches=args.accumulate,
        sync_batchnorm=True,
        limit_val_batches=args.limit_val_batches,
        **val_control,
    )
    # DataParallel ("dp") is only requested when GPUs are configured; on a
    # CPU-only run (gpus unset) the trainer is left on its default.
    if args.gpus:
        trainer_kwargs["accelerator"] = "dp"
    trainer = pl.Trainer(**trainer_kwargs)

    # Train
    trainer.fit(forecaster, datamodule=data_module)

    # Test
    trainer.test(datamodule=data_module, ckpt_path="best")

    # Predict (only here as a demo and test)
    # forecaster.to("cuda")
    # xc, yc, xt, _ = test_samples
    # yt_pred = forecaster.predict(xc, yc, xt)

    # if args.wandb:
    #     experiment.finish()


**main run**

In [6]:
# Parse the baseline command line from a string, since a notebook has no real CLI.
parser = create_parser()
args = parser.parse_args(
    "--model spacetimeformer --dset solar_energy --context_points 168 --target_points 24 --d_model 100 --d_ff 400 --enc_layers 5 --dec_layers 5 --l2_coeff 1e-3 --dropout_ff .2 --dropout_emb .1 --d_qk 20 --d_v 20 --n_heads 6 --run_name spatiotemporal_al_solar --batch_size 32 --class_loss_imp 0 --initial_downsample_convs 1 --decay_factor .8 --warmup_steps 1000".split()
)
args

Namespace(accumulate=1, activation='gelu', attn_factor=5, attn_plot=False, attn_time_windows=1, base_lr=0.0005, batch_size=32, class_loss_imp=0.0, context_points=168, d_ff=400, d_model=100, d_qk=20, d_v=20, data_path='auto', debug=False, dec_layers=5, decay_factor=0.8, dropout_attn_matrix=0.0, dropout_attn_out=0.0, dropout_emb=0.1, dropout_ff=0.2, dropout_qkv=0.0, dset='solar_energy', embed_method='spatio-temporal', enc_layers=5, global_cross_attn='performer', global_self_attn='performer', gpus=None, grad_clip_norm=0, init_lr=1e-10, initial_downsample_convs=1, intermediate_downsample_convs=0, l2_coeff=0.001, learning_rate=0.0001, limit_val_batches=1.0, linear_shared_weights=False, linear_window=0, local_cross_attn='performer', local_self_attn='performer', loss='mse', model='spacetimeformer', n_heads=6, no_earlystopping=False, no_final_norm=False, no_given=False, no_space=False, no_time=False, no_val=False, norm='batch', overfit=False, patience=5, performer_kernel='relu', performer_redr

In [7]:
# Resolve the directory used for logs and checkpoints, preferring $STF_LOG_DIR.
log_dir = os.getenv("STF_LOG_DIR")
if log_dir is None:
    log_dir = "./data/STF_LOG_DIR"
    # Announce the fallback only when it is actually in use.
    print(
        "Using default wandb log dir path of ./data/STF_LOG_DIR. This can be adjusted with the environment variable `STF_LOG_DIR`"
    )
# exist_ok makes the cell safe to re-run.
os.makedirs(log_dir, exist_ok=True)

Using default wandb log dir path of ./data/STF_LOG_DIR. This can be adjusted with the environment variable `STF_LOG_DIR`


**in create_dset**

1. CSVTimeSeries : csv time series : 전처리하는 class
2. CSVTorchDset : dataset -> torch dataset 만드는 class
3. DataModule : Dataloader : from pl.lightning

In [8]:
# Step-by-step version of create_dset(args).  The one-call equivalent is:
# (
#     data_module,
#     inv_scaler,
#     scaler,
#     null_val,
#     plot_var_idxs,
#     plot_var_names,
#     pad_val,
#     ) = create_dset(args)


time_col_name = "Datetime"
time_features = ["year", "month", "day", "weekday", "hour", "minute"]

# Honour an explicit --data_path; "auto" selects the default AL Solar CSV,
# matching what create_dset does for the solar_energy dataset.
data_path = args.data_path
if data_path == "auto":
    data_path = "./data/solar_AL_converted.csv"
target_cols = [str(i) for i in range(137)]

In [9]:
# Load the CSV into the pre-processing wrapper, using the settings defined
# in the previous cell (val_split and test_split are 0.2 each).
dset = stf.data.CSVTimeSeries(
    data_path=data_path,
    time_col_name=time_col_name,
    time_features=time_features,
    target_cols=target_cols,
    ignore_cols="all",
    val_split=0.2,
    test_split=0.2,
)


In [11]:
# Preview the first rows of the training split.
dset.train_data.head()

# The frame holds the time columns plus the target columns (the variables); the values shown are after normalization / scaling.


Unnamed: 0,Datetime,0,1,2,3,4,5,6,7,8,...,133,134,135,136,Year,Month,Day,Weekday,Hour,Minute
0,2006-01-01 00:00:00,-0.802494,-0.713569,-0.698395,-0.700657,-0.791345,-0.696641,-0.722225,-0.698887,-0.701288,...,-0.713636,-0.701126,-0.786394,-0.694511,0.0,-1.0,-1.0,1.0,-1.0,-1.0
1,2006-01-01 00:10:00,-0.802494,-0.713569,-0.698395,-0.700657,-0.791345,-0.696641,-0.722225,-0.698887,-0.701288,...,-0.713636,-0.701126,-0.786394,-0.694511,0.0,-1.0,-1.0,1.0,-1.0,-0.661017
2,2006-01-01 00:20:00,-0.802494,-0.713569,-0.698395,-0.700657,-0.791345,-0.696641,-0.722225,-0.698887,-0.701288,...,-0.713636,-0.701126,-0.786394,-0.694511,0.0,-1.0,-1.0,1.0,-1.0,-0.322034
3,2006-01-01 00:30:00,-0.802494,-0.713569,-0.698395,-0.700657,-0.791345,-0.696641,-0.722225,-0.698887,-0.701288,...,-0.713636,-0.701126,-0.786394,-0.694511,0.0,-1.0,-1.0,1.0,-1.0,0.016949
4,2006-01-01 00:40:00,-0.802494,-0.713569,-0.698395,-0.700657,-0.791345,-0.696641,-0.722225,-0.698887,-0.701288,...,-0.713636,-0.701126,-0.786394,-0.694511,0.0,-1.0,-1.0,1.0,-1.0,0.355932


**dataloader 만들기**

In [12]:
# Build the data module around the CSV series.  Window lengths and loader
# settings all come from the parsed namespace `args`.
DATA_MODULE = stf.data.DataModule(
    datasetCls=stf.data.CSVTorchDset,
    batch_size=args.batch_size,
    workers=args.workers,
    overfit=args.overfit,
    dataset_kwargs=dict(
        csv_time_series=dset,
        context_points=args.context_points,
        target_points=args.target_points,
        time_resolution=args.time_resolution,
    ),
)

In [16]:
# Pull a single batch from the training dataloader for inspection.
sample = next(iter(DATA_MODULE.train_dataloader()))

In [20]:
print(args.batch_size)
print(len(sample)) # a batch has four parts; presumably the order of the dataset's self._torch(ctxt_x, ctxt_y, trgt_x, trgt_y) — TODO confirm
print(sample[0].shape) # (batch size, sequence length, feature count); the recorded output is [32, 168, 6], i.e. 168 = context_points only

32
4
torch.Size([32, 168, 6])


In [33]:
# Show the second item of the batch from the first tensor of the sample.
sample[0][1]

tensor([[ 0.9333, -0.4783, -0.6610, -0.2727, -0.6667,  0.0000],
        [ 0.9333, -0.4783, -0.3220, -0.2727, -0.6667,  0.0000],
        [ 0.9333, -0.4783,  0.0169, -0.2727, -0.6667,  0.0000],
        ...,
        [ 1.0000, -0.2174,  0.3559, -0.2727, -0.3333,  0.0000],
        [ 1.0000, -0.2174,  0.6949, -0.2727, -0.3333,  0.0000],
        [ 1.0000, -0.1304, -1.0000, -0.2727, -0.3333,  0.0000]])

In [21]:
# Hand-built counterpart of create_dset's return values.
# (Identity scalers `lambda x: x` are not used; the dataset's own ones are.)
NULL_VAL = PAD_VAL = None
PLOT_VAR_IDXS = PLOT_VAR_NAMES = None

# Pre-processing: scaling and its inverse are methods of the CSV series.
INV_SCALER, SCALER = dset.reverse_scaling, dset.apply_scaling

# Lower-case aliases under the names that main() unpacks.
(
    data_module,
    inv_scaler,
    scaler,
    null_val,
    plot_var_idxs,
    plot_var_names,
    pad_val,
) = (
    DATA_MODULE,
    INV_SCALER,
    SCALER,
    NULL_VAL,
    PLOT_VAR_IDXS,
    PLOT_VAR_NAMES,
    PAD_VAL,
)


**model**

In [26]:
    # Model
# Pass the dataset's null / pad markers on through the shared namespace.
args.null_value, args.pad_value = null_val, pad_val
# forecaster = create_model(args)

In [27]:
# Build the forecaster through create_model() instead of repeating its long
# keyword list in this cell, so the two copies cannot drift apart.
config = args  # alias kept so the name stays available to other cells

# Dimensions for the solar energy dataset.  create_model() determines the
# values it passes to the forecaster itself; these are kept only so that the
# names remain defined in the notebook namespace.
x_dim = 6
yc_dim = 137
yt_dim = 137

# Sequence length covering context plus target points (kept for inspection).
if hasattr(args, "context_points") and hasattr(args, "target_points"):
    max_seq_len = args.context_points + args.target_points
elif hasattr(args, "max_len"):
    max_seq_len = args.max_len
else:
    raise ValueError("Undefined max_seq_len")

forecaster = create_model(config)

Forecaster
	L2: 0.001
	Linear Window: 0
	Linear Shared Weights: False
	RevIN: False
	Decomposition: False
GlobalSelfAttn: AttentionLayer(
  (inner_attention): PerformerAttention(
    (kernel_fn): ReLU()
  )
  (query_projection): Linear(in_features=100, out_features=120, bias=True)
  (key_projection): Linear(in_features=100, out_features=120, bias=True)
  (value_projection): Linear(in_features=100, out_features=120, bias=True)
  (out_projection): Linear(in_features=120, out_features=100, bias=True)
  (dropout_qkv): Dropout(p=0.0, inplace=False)
)
GlobalCrossAttn: AttentionLayer(
  (inner_attention): PerformerAttention(
    (kernel_fn): ReLU()
  )
  (query_projection): Linear(in_features=100, out_features=120, bias=True)
  (key_projection): Linear(in_features=100, out_features=120, bias=True)
  (value_projection): Linear(in_features=100, out_features=120, bias=True)
  (out_projection): Linear(in_features=120, out_features=100, bias=True)
  (dropout_qkv): Dropout(p=0.0, inplace=False)
)
L

In [None]:
# Attach the dataset's scaling helpers and null marker to the model.
forecaster.set_inv_scaler(inv_scaler) # scaler assignment in model
forecaster.set_scaler(scaler)
forecaster.set_null_value(null_val)

In [28]:
# Callbacks
callbacks = create_callbacks(args, save_dir=log_dir)
test_samples = next(iter(data_module.test_dataloader())) # fetch just one batch from the test dataloader

In [29]:
print(args.batch_size)
print(len(test_samples)) # a batch has four parts; presumably the order of the dataset's self._torch(ctxt_x, ctxt_y, trgt_x, trgt_y) — TODO confirm
print(test_samples[0].shape) # (batch size, sequence length, feature count); the recorded output is [32, 168, 6], i.e. 168 = context_points only

32
4
torch.Size([32, 168, 6])


In [31]:
print(args.val_check_interval)

# Values up to 1.0 are passed through as `val_check_interval`; anything larger
# is truncated to an int and used as `check_val_every_n_epoch` instead.
val_control = (
    {"val_check_interval": args.val_check_interval}
    if args.val_check_interval <= 1.0
    else {"check_val_every_n_epoch": int(args.val_check_interval)}
)

print(val_control)

1.0
{'val_check_interval': 1.0}


train epoch (전체 학습데이터)의 일부(여기서는 1이니까 전체)를 할때마다 validation

How often to check the validation set. Pass a float in the range [0.0, 1.0] to check after a fraction of the training epoch. Pass an int to check after a fixed number of training batches. An int value can only be higher than the number of training batches when check_val_every_n_epoch=None, which validates after every N training batches across epochs or during iteration-based training. Default: 1.0

In [34]:
# PyTorch Lightning trainer definition.  Same settings as in main(), with the
# "dp" accelerator left out because no GPU is available on this machine.
trainer = pl.Trainer(
    callbacks=callbacks,
    gpus=args.gpus,
    # logger=logger if args.wandb else None,
    # accelerator="dp",  # disabled: no GPU here
    accumulate_grad_batches=args.accumulate,
    gradient_clip_algorithm="norm",
    gradient_clip_val=args.grad_clip_norm,
    limit_val_batches=args.limit_val_batches,
    overfit_batches=20 if args.debug else 0,
    sync_batchnorm=True,
    **val_control,
)

  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


In [37]:
# Re-display the namespace; create_callbacks has added model_ckpt_dir to it.
args

Namespace(accumulate=1, activation='gelu', attn_factor=5, attn_plot=False, attn_time_windows=1, base_lr=0.0005, batch_size=32, class_loss_imp=0.0, context_points=168, d_ff=400, d_model=100, d_qk=20, d_v=20, data_path='auto', debug=False, dec_layers=5, decay_factor=0.8, dropout_attn_matrix=0.0, dropout_attn_out=0.0, dropout_emb=0.1, dropout_ff=0.2, dropout_qkv=0.0, dset='solar_energy', embed_method='spatio-temporal', enc_layers=5, global_cross_attn='performer', global_self_attn='performer', gpus=None, grad_clip_norm=0, init_lr=1e-10, initial_downsample_convs=1, intermediate_downsample_convs=0, l2_coeff=0.001, learning_rate=0.0001, limit_val_batches=1.0, linear_shared_weights=False, linear_window=0, local_cross_attn='performer', local_self_attn='performer', loss='mse', model='spacetimeformer', model_ckpt_dir='./data/STF_LOG_DIR/spatiotemporal_al_solar_7f10f866', n_heads=6, no_earlystopping=False, no_final_norm=False, no_given=False, no_space=False, no_time=False, no_val=False, norm='batc

In [39]:
# Quick check of the window sizes and the time-embedding size in use.
print(args.target_points)
print(args.context_points)
print(args.time_emb_dim)

24
168
6


**training**

In [None]:
# Train: fit the forecaster on the data module built above.
trainer.fit(forecaster, datamodule=data_module)

In [None]:
# Test: evaluate the best checkpoint on the test split.
trainer.test(datamodule=data_module, ckpt_path="best")

# Predict (only here as a demo and test)
# Move the model to the GPU only when one exists; the recorded trainer output
# reports "GPU available: False", where an unconditional "cuda" would fail.
forecaster.to("cuda" if torch.cuda.is_available() else "cpu")
xc, yc, xt, _ = test_samples
yt_pred = forecaster.predict(xc, yc, xt)

In [None]:
# if __name__ == "__main__":
#     # CLI
#     parser = create_parser()
#     args = parser.parse_args()

#     for trial in range(args.trials): # this is where we will keep iterating over the regional data
#         main(args)

# Dataset and Dataloader