In [1]:
from utils import *
from models import *
from training import *

In [2]:
# Experiment configuration: every setting for this run is collected in a
# single flat dictionary, grouped below by the part of the pipeline it targets.
configs = {
    # ---------------- Model ----------------
    # Model name (also used as the base name of the saved .json config)
    "model_name": "resnet",
    # Number of input channels
    "input_nc": 1,
    # Number of output channels
    "output_nc": 1,
    # Number of filters in the first conv layer
    "ngf": 64,
    # Number of residual blocks
    "num_blocks": 9,
    # Normalization layer, given by name
    "norm_layer": "BatchNorm3d",
    # Whether dropout layers are used
    "use_dropout": False,
    # Padding type
    "padding_type": "reflect",

    # ---------------- Training ----------------
    # Number of epochs
    "n_epochs": 100,
    # Loss function, given by name
    "loss": "mse_loss",
    # Optimizer, given by name
    "optimizer": "Adam",
    # Evaluation metric, given by name
    "evaluation_metric": "mse_loss",
    # Whether the best model is saved
    "save_best": True,
    # Regularization switch
    "regularized": False,
    # Variational autoencoder switch
    "vae": False,

    # ---------------- Data ----------------
    # Location of the input data
    "main_data_path": "D:\\Brain-MINDS\\model_data",
    # Location of the training logs
    "training_log_path": "D:\\tract_residuals\\logs",
    # Model filename
    # NOTE(review): the value looks like a folder rather than a file name —
    # confirm what run_pytorch_training expects here.
    "model_filename": "D:\\tract_residuals\\models",
    # Batch size used for training
    "batch_size": 1,
    # Batch size used for validation
    "validation_batch_size": 1,

    # ---------------- Parameters ----------------
    # Initial learning rate
    "initial_learning_rate": 1e-4,
    # Early stopping patience
    "early_stopping_patience": 50,
    # Learning rate decay patience
    "decay_patience": 20,
    # Learning rate decay factor
    "decay_factor": 0.5,
    # Minimum learning rate
    "min_learning_rate": 1e-8,
    # Number of most recent models to keep
    "save_last_n_models": 10,

    # ---------------- Misc ----------------
    # Set to True to skip validation
    "skip_val": False,
}

In [3]:
# Define the configuration path: configs/<model_name>.json, relative to the
# directory the notebook is run from
config_path = os.path.join("configs", configs["model_name"] + ".json")

# Make sure the target folder exists before writing, so this cell also works
# on a fresh checkout where "configs" has not been created yet
os.makedirs(os.path.dirname(config_path), exist_ok=True)

# Save the configuration
dump_json(configs, config_path)

In [4]:
# Read the configuration back from the .json file written earlier
configs = load_json(config_path)

# Both the monitored metric and the data groups depend on whether
# validation is skipped, so they are chosen together in one branch
if configs["skip_val"]:
    metric_to_monitor, groups = "loss", ("training",)
else:
    metric_to_monitor, groups = "val_loss", ("training", "validation")

# One-element tuple holding the configured evaluation metric
model_metrics = (configs["evaluation_metric"],)

# Launch training with the model and log locations taken from the config
run_pytorch_training(
    configs,
    configs["model_filename"],
    configs["training_log_path"],
    metric_to_monitor=metric_to_monitor,
    bias=None,
)

Model is: ResnetEncoder
Criterion is: function
Optimizer is: Adam
b0_images shape: torch.Size([1, 256, 356, 230])
wmfod_norms shape: torch.Size([1, 128, 178, 115, 45])
residuals shape: torch.Size([1, 256, 356, 230])
injection_centers shape: torch.Size([1, 3])
b0_images shape: (256, 356, 230)
wmfod_norms shape: (128, 178, 115, 45)
residuals shape: (256, 356, 230)
