In [None]:
import os
import yaml
import torch
import importlib
import numpy as np
import matplotlib.pyplot as plt
from physioex.physioex.data import PhysioExDataModule
from src.attribution.flextime import Filterbank

In [None]:
# Data: a PhysioExDataModule over the "sleepedf" dataset with "raw"
# preprocessing and only the "EEG" channel selected. The relative paths below
# resolve against the notebook's current working directory.
datamodule = PhysioExDataModule(
    datasets=["sleepedf"],
    batch_size=64,              # samples per DataLoader batch
    preprocessing="raw",
    selected_channels=["EEG"],
    sequence_length=21,
    data_folder="./data",
)

# One DataLoader per split, requested in train -> val -> test order.
train_loader, val_loader, test_loader = (
    datamodule.train_dataloader(),
    datamodule.val_dataloader(),
    datamodule.test_dataloader(),
)

# Number of batches in each split, one line each (train, val, test).
for split_loader in (train_loader, val_loader, test_loader):
    print(len(split_loader))
del split_loader

# Directory for model checkpoints (not used within this cell).
checkpoint_path = "./model/checkpoint/"

In [None]:
# Load the experiment configuration from the notebook's working directory.
# Alternative config path kept for reference:
#   ./physioex/physioex/train/networks/config/chambon2018.yaml
with open("./config.yaml", "r") as file:
    config = yaml.safe_load(file)

print(config)
network_config = config["model_config"]

print(network_config)

# Resolve the loss class from its "package.module:ClassName" spec string.
loss_package, loss_class = network_config["loss_call"].split(":")
model_loss = getattr(importlib.import_module(loss_package), loss_class)

print(model_loss)

# `model_name` is optional in the config. Bind it unconditionally (None when
# the key is absent) so the print below cannot raise NameError on a fresh
# kernel; the branch condition itself is unchanged.
model_name = config.get("model_name")
if "model_name" in config:
    # FIXME(review): this re-reads `loss_call`, so `model` ends up bound to the
    # *loss* class, not a model -- it looks like a copy-paste of the loss block
    # above. The intended config key is not visible here; confirm and fix.
    model_package, model_loader = network_config["loss_call"].split(":")
    model = getattr(importlib.import_module(model_package), model_loader)

print(model_name)

# Resolve the network class from its "package.module:ClassName" spec string.
# The spec's class-name part gets its own variable so `model_class` always
# refers to the class object, never to a string.
model_package, model_class_name = config["module"].split(":")
model_class = getattr(importlib.import_module(model_package), model_class_name)

print(model_class)