In [9]:
import os
import numpy as np
import torch


from lightning.pytorch.trainer import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from dsanet_model import DSANet
from dsanet_model import add_model_config
from data import MTSFDataset
from torch.utils.data import DataLoader
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import warnings

# Silence all Python warnings so the training log stays compact.
# NOTE(review): this also hides deprecation and Lightning configuration warnings.
warnings.filterwarnings("ignore")

# Seed every RNG the run touches so repeated runs are comparable.
SEED = 7
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)

# Project root; `add_model_config` builds the run configuration from it.
root_dir = "/root/myDL/DSANet"
log_dir = os.path.join(root_dir, "dsanet_logs")

config = add_model_config(root_dir)

model = DSANet(config)

# TensorBoard logs are written under config["log_dir"]/electricity/version_2.
tb_logger = TensorBoardLogger(name="electricity", save_dir=config["log_dir"], version=2)

# Stop once val_loss has failed to improve by at least `min_delta` for
# `patience` consecutive validation runs. A loss gets better as it goes DOWN,
# so the mode must be "min"; with "max" a steadily decreasing loss is counted
# as "no improvement" and training is cut short after `patience` epochs.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.001,
    patience=3,
    verbose=False,
    mode="min",
)

trainer = Trainer(
    devices=1,
    logger=tb_logger,
    max_epochs=20,
    callbacks=[early_stop_callback],
)

trainer.fit(model)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type                          | Params | Mode 
----------------------------------------------------------------------
0 | sgsf        | Single_Global_SelfAttn_Module | 18.9 M | train
1 | slsf        | Single_Local_SelfAttn_Module  | 18.9 M | train
2 | ar          | AR                            | 65     | train
3 | W_output1   | Linear                        | 65     | train
4 | dropout     | Dropout                       | 0      | train
5 | active_func | Tanh                          | 0      | train
----------------------------------------------------------------------
37.9 M    Trainable params
0         Non-trainable params
37.9 M    Total params
151.590   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [4]:
# Evaluate on the test set. No model or ckpt_path is passed; the recorded
# output below shows Lightning restoring weights from a saved checkpoint
# before running the test loop.
# NOTE(review): the recorded output carries execution count 4 and points at a
# `version_1` checkpoint, while the training cell carries count 9 and logs to
# version=2 -- the displayed metrics may be stale; re-run after training.
trainer.test()

Restoring states from the checkpoint path at /root/myDL/DSANet/lightning_logs/electricity/version_1/checkpoints/epoch=19-step=8560.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /root/myDL/DSANet/lightning_logs/electricity/version_1/checkpoints/epoch=19-step=8560.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
          CORR                      nan
          NMSE             0.025881389155983925
          RRSE              0.11289846152067184
        test_loss                1183922.0
     test_loss_epoch            1184752.75
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss_epoch': 1184752.75,
  'test_loss': 1183922.0,
  'RRSE': 0.11289846152067184,
  'CORR': nan,
  'NMSE': 0.025881389155983925}]

In [8]:
# Debugging aid: display the EarlyStopping callback attached to `trainer`
# (presumably the one passed via `callbacks=` at construction -- confirm).
trainer.early_stopping_callback

In [None]:
# Load a saved checkpoint (original note: "load and continue training").
# As written, this cell only evaluates the checkpoint with `trainer.test`;
# it does not resume fitting.
import lightning.pytorch as pl

# Candidate checkpoints from earlier runs; only `ckp3` is evaluated below.
ckp1 = "lightning_logs/csi/version_3/checkpoints/epoch=19-step=4280.ckpt"
ckp2 = "lightning_logs/csi/version_4/checkpoints/epoch=39-step=8560.ckpt"
ckp3 = 'lightning_logs/csi/version_0/checkpoints/epoch=4-step=1070.ckpt'

# CSV metrics logger plus a checkpoint callback that saves every 60 training
# steps and keeps every checkpoint it writes (save_top_k=-1).
csv_logger = pl.loggers.CSVLogger(save_dir="lightning_logs/csi", name="version_0")
step_checkpoint = pl.callbacks.ModelCheckpoint(every_n_train_steps=60, save_top_k=-1)
trainer = pl.Trainer(logger=csv_logger, callbacks=[step_checkpoint])

# Fresh model instance; its weights are replaced by the checkpoint at test time.
model = DSANet(config)
trainer.test(model=model, ckpt_path=ckp3)

In [None]:
# Load trained weights from a checkpoint and run one forward pass on random
# input as a smoke test of the inference path.
PATH = "lightning_logs/csi/version_4/checkpoints/epoch=39-step=8560.ckpt"
model = DSANet.load_from_checkpoint(PATH, config=config)

print(model.config["learning_rate"])

device = "cuda" if torch.cuda.is_available() else "cpu"
# Place the model on the same device as the input explicitly. The device the
# model lands on after `load_from_checkpoint` depends on how the checkpoint
# was saved, so relying on it can raise a device-mismatch error in forward().
model.to(device)
model.eval()

# Random dummy batch; presumably (batch, window, n_series) -- TODO confirm
# against the model config.
x = torch.randn(1, 64, 32).to(device)
with torch.no_grad():
    y_hat = model(x)
print(y_hat)

In [None]:
# Display the shape of the dummy input tensor `x` built in the inference cell.
x.shape

In [None]:
# Rebuild a Trainer whose only explicit callback checkpoints every 60 training
# steps and keeps every checkpoint it writes (save_top_k=-1).
step_checkpoint = pl.callbacks.ModelCheckpoint(every_n_train_steps=60, save_top_k=-1)
trainer = pl.Trainer(callbacks=[step_checkpoint])

# Re-instantiate the model from the shared `config`.
model = DSANet(config)


In [None]:
import tensorboardX