In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell runs, so edits to the project packages are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

from deep_fib.sci_net import SCIBlockCfg, SCINet
from deep_fib.data import DeepFIBDataset
from deep_fib.core import train_step, test_step

from utils.data import Marconi100Dataset, get_train_test_split
from utils.training import training_loop
from utils.summary import SummaryWriter

In [13]:
# --- Run configuration -------------------------------------------------------
# Every tunable value for this notebook lives in this cell.

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Seed torch's random number generators (CPU and CUDA) so that weight
# initialisation, dropout masks and DataLoader shuffling repeat across
# "Restart & Run All" executions.
# NOTE(review): this does not seed `random` / `numpy`; if the project dataset
# code draws samples with those, seed them as well. Some CUDA kernels may
# still be non-deterministic.
seed = 42
torch.manual_seed(seed)

# Window length: passed to DeepFIBDataset and used as both input_len and
# output_len of the SCINet model.
horizon = 1024
# Third positional argument of DeepFIBDataset.
# NOTE(review): presumably the number of windows drawn per series — confirm.
n_sample_per_data = 5

# DataLoader settings (shared by the train and test loaders).
batch_size = 32
num_workers = 2

# Forwarded to SCINet(num_encoder_levels=...).
num_encoder_levels = 2

# Directory for SummaryWriter logs; checkpoints go to log_dir + "/models".
log_dir = "./trash"
# Initial Adam learning rate.
lr = 1e-3
num_epochs = 3
# StepLR period: the learning rate is multiplied by StepLR's default
# gamma (0.1) every `step_size` epochs.
step_size = 2

# Forwarded to SCINet(hidden_decoder_sizes=...).
hidden = None
# Configuration of the SCI blocks used by the model.
# NOTE(review): input_dim=460 presumably equals the number of feature
# columns in the Marconi100 data — confirm against the dataset.
block_cfg = SCIBlockCfg(
    input_dim=460,
    hidden_size=4,
    kernel_size=3,
    dropout=0.5,
)

cuda


In [14]:
# Obtain the two partitions that are wrapped by Marconi100Dataset below.
# NOTE(review): 0.1 looks like the held-out fraction and 42 like the RNG
# seed of the split — confirm against utils.data.get_train_test_split.
split_args = (0.1, 42)
train, test = get_train_test_split(*split_args)

In [15]:
# Optional quick-run switch (disabled): uncomment the lines below to keep only
# the first tenth of each partition, e.g. to smoke-test the pipeline.
# Caution: the slicing reassigns `train` / `test` in place, so re-running this
# cell shrinks them again — re-run the split cell first to start from the
# full partitions.
# train = train[:len(train)//10]
# test = test[:len(test)//10]
# len(train), len(test)

In [16]:
# Build the training dataset: Marconi100Dataset wraps the partition with
# normalize=True, then DeepFIBDataset wraps that with the configured horizon
# and n_sample_per_data.
marconi_train = Marconi100Dataset(train, normalize=True)
dataset_train = DeepFIBDataset(marconi_train, horizon, n_sample_per_data)

# Same two-step wrapping for the test partition.
marconi_test = Marconi100Dataset(test, normalize=True)
dataset_test = DeepFIBDataset(marconi_test, horizon, n_sample_per_data)

# Show both dataset sizes as the cell output.
len(dataset_train), len(dataset_test)

Loading: 100%|██████████| 224/224 [01:16<00:00,  2.93it/s]
Loading: 100%|██████████| 25/25 [00:09<00:00,  2.63it/s]


(28220, 3235)

In [7]:
# Settings common to both loaders.
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)

# Only the training loader shuffles; the test loader keeps a fixed order.
train_loader = DataLoader(dataset_train, shuffle=True, **loader_kwargs)
test_loader = DataLoader(dataset_test, shuffle=False, **loader_kwargs)

In [18]:
# Instantiate the network; input and output length are both `horizon`.
model = SCINet(
    input_len=horizon,
    output_len=horizon,
    block_config=block_cfg,
    num_encoder_levels=num_encoder_levels,
    hidden_decoder_sizes=hidden,
)
# Cast parameters to float32 and move them to the selected device.
model = model.float().to(device)

# Adam optimiser with a step-wise learning-rate schedule.
optim = Adam(model.parameters(), lr=lr)
lr_sched = StepLR(optim, step_size=step_size)

# Everything the training loop needs except the writer, which only exists
# inside the `with` block below.
loop_kwargs = dict(
    model=model,
    train_step_function=train_step,
    test_step_function=test_step,
    num_epochs=num_epochs,
    train_dataloader=train_loader,
    test_dataloader=test_loader,
    device=device,
    optimizer=optim,
    lr_scheduler=lr_sched,
    save_path=log_dir + "/models",
)

# The context manager handles the writer's lifetime around the training run.
with SummaryWriter(log_dir) as writer:
    training_loop(writer=writer, **loop_kwargs)

                                                          

Epoch 0 - train_loss = 151.315 - test_loss = 137.507 - lr = 1.00e-03


                                                          

Epoch 1 - train_loss = 109.902 - test_loss = 134.273 - lr = 1.00e-03


                                                          

Epoch 2 - train_loss = 98.074 - test_loss = 126.583 - lr = 1.00e-04


In [22]:
%load_ext tensorboard
%tensorboard --logdir=trash

Reusing TensorBoard on port 6006 (pid 17280), started 0:01:38 ago. (Use '!kill 17280' to kill it.)