In [1]:
# Reload all imported modules before each cell executes, so edits to the local
# packages (deep_fib, utils) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

from deep_fib.sci_net import SCIBlockCfg, SCINet
from deep_fib.data import DeepFIBDataset
from deep_fib.core import train_step, test_step

from utils.data import Marconi100Dataset, get_train_test_split
from utils.training import training_loop
from utils.summary import SummaryWriter

In [4]:
# Run configuration: every value a reader might tune for this experiment.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG (all devices) so weight init, dropout and DataLoader
# shuffling are reproducible under Restart & Run All.
# NOTE(review): if the datasets draw samples via `random`/`numpy`, seed those too.
seed = 42
torch.manual_seed(seed)

# Data: both values are passed to DeepFIBDataset; `horizon` is also used as
# SCINet's input and output length.
horizon = 1024
n_sample_per_data = 2

# DataLoader settings
batch_size = 4
num_workers = 0

# Model
num_encoder_levels = 2
hidden = None  # forwarded to SCINet as `hidden_decoder_sizes`

# Logging / optimization
log_dir = "./trash"  # SummaryWriter output; checkpoints are saved under <log_dir>/models
lr = 1e-3
num_epochs = 3
step_size = 1  # StepLR decay interval

block_cfg = SCIBlockCfg(
    # NOTE(review): presumably the number of input channels in the Marconi100 data;
    # consider deriving it from the dataset instead of hard-coding 460.
    input_dim=460,
    hidden_size=2,
    kernel_size=3,
    dropout=0.5,
)

In [5]:
# Partition the Marconi100 entries into train/test lists.
train, test = get_train_test_split(
    0.5,  # NOTE(review): presumably the split fraction — confirm whether it is train or test
    42,   # presumably the shuffle seed, which keeps the partition reproducible
)

In [6]:
# Smoke-test subset: keep only the first entry of each split.
# This rebinds `train`/`test`; re-run the split cell above to get the full lists back.
train, test = train[:1], test[:1]

In [7]:
def _make_deepfib_dataset(entries):
    """Wrap a split in a normalized Marconi100Dataset, then in a DeepFIBDataset.

    Uses the configured `horizon` and `n_sample_per_data`.
    """
    # NOTE(review): each split is normalized independently; if `normalize=True`
    # uses per-dataset statistics, the test set is not scaled with train stats.
    return DeepFIBDataset(Marconi100Dataset(entries, normalize=True), horizon, n_sample_per_data)


dataset_train = _make_deepfib_dataset(train)
dataset_test = _make_deepfib_dataset(test)
len(dataset_train), len(dataset_test)

Loading: 100%|██████████| 1/1 [00:00<00:00,  1.14it/s]
Loading: 100%|██████████| 1/1 [00:00<00:00,  1.39it/s]


(52, 50)

In [8]:
# Shuffle only the training batches; test order stays fixed so per-epoch
# evaluation sees the same sequence.
train_loader = DataLoader(
    dataset_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)
test_loader = DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

In [14]:
# Build SCINet with input_len == output_len == horizon, cast its floating-point
# parameters to float32, and move it to the selected device.
model = (
    SCINet(
        output_len=horizon,
        input_len=horizon,
        num_encoder_levels=num_encoder_levels,
        hidden_decoder_sizes=hidden,
        block_config=block_cfg,
    )
    .float()
    .to(device)
)

optim = Adam(model.parameters(), lr=lr)
# StepLR multiplies the LR by `gamma` every `step_size` scheduler steps.
# gamma=0.1 is PyTorch's default, written out here. With the configured
# step_size, the log below shows a 10x drop per epoch (1e-3 -> 1e-4 -> 1e-5).
lr_sched = StepLR(optim, step_size=step_size, gamma=0.1)

# Train/evaluate for `num_epochs`. The writer is used as a context manager
# (presumably flushed/closed on exit). Checkpoints go to <log_dir>/models.
with SummaryWriter(log_dir) as writer:
    training_loop(
        model=model,
        train_step_function=train_step,
        test_step_function=test_step,
        num_epochs=num_epochs,
        train_dataloader=train_loader,
        test_dataloader=test_loader,
        device=device,
        optimizer=optim,
        lr_scheduler=lr_sched,
        writer=writer,
        save_path=f"{log_dir}/models",
    )

                                                        

Epoch 0 - train_loss = 326.656 - test_loss = 248.665 - lr = 1.00e-03


                                                        

Epoch 1 - train_loss = 252.862 - test_loss = 238.825 - lr = 1.00e-04


                                                        

Epoch 2 - train_loss = 246.549 - test_loss = 237.407 - lr = 1.00e-05
