In [2]:
import lightning as L
import os
from batchgenerators.utilities.file_and_folder_operations import load_json
from yucca.pipeline.managers.YuccaManager import YuccaManager
from yucca.paths import yucca_raw_data, yucca_preprocessed_data, yucca_models
from yucca.pipeline.configuration.configure_task import TaskConfig
from yucca.pipeline.configuration.configure_paths import get_path_config
from yucca.pipeline.configuration.configure_callbacks import get_callback_config
from yucca.pipeline.configuration.split_data import get_split_config
from yucca.pipeline.configuration.configure_input_dims import InputDimensionsConfig
from yucca.data.augmentation.YuccaAugmentationComposer import YuccaAugmentationComposer
from yucca.data.augmentation.augmentation_presets import generic
from yucca.lightning_modules.YuccaLightningModule import YuccaLightningModule
from yucca.data.data_modules.YuccaDataModule import YuccaDataModule

Define the experiment configuration (task, model, training hyperparameters and data-split settings) used by the cells below.

In [3]:
# Experiment configuration consumed by every downstream cell.
# Key order is preserved deliberately: the Lightning module logs this dict verbatim.
config = dict(
    batch_size=2,
    dims="2D",
    deep_supervision=False,
    experiment="default",
    learning_rate=0.001,
    loss_fn="DiceCE",
    model_name="TinyUNet",
    momentum=0.99,
    num_classes=3,
    num_modalities=1,
    patch_size=(32, 32),
    plans_name="demo",
    plans=None,
    split_idx=0,
    split_method="kfold",
    split_param=5,
    task="Task001_OASIS",
    task_type="segmentation",
)

In [4]:
# Build the Yucca configuration objects, the augmentation pipeline, the Lightning module
# and the data module from the `config` dict defined above.
# Keys are read with `config[...]` rather than `config.get(...)` so that a misspelled key
# raises KeyError immediately instead of silently passing None (a typo such as
# "num_modalitites" previously sent num_modalities=None into InputDimensionsConfig).
input_dims_config = InputDimensionsConfig(
    batch_size=config["batch_size"],
    patch_size=config["patch_size"],
    num_modalities=config["num_modalities"],
)
task_config = TaskConfig(
    task=config["task"],
    continue_from_most_recent=True,
    experiment=config["experiment"],
    manager_name="",
    model_dimensions=config["dims"],
    model_name=config["model_name"],
    patch_based_training=True,
    planner_name=config["plans_name"],
    split_idx=config["split_idx"],
    split_method=config["split_method"],
    split_param=config["split_param"],
)

# Paths (save dir, version dir, preprocessed-data dir) are derived from the task config.
path_config = get_path_config(task_config=task_config)

split_config = get_split_config(method=task_config.split_method, param=task_config.split_param, path_config=path_config)

# Logging is switched off for this demo run.
callback_config = get_callback_config(
    save_dir=path_config.save_dir,
    version_dir=path_config.version_dir,
    experiment=task_config.experiment,
    version=path_config.version,
    enable_logging=False,
)

augmenter = YuccaAugmentationComposer(
    deep_supervision=config["deep_supervision"],
    patch_size=input_dims_config.patch_size,
    is_2D=config["dims"] == "2D",
    parameter_dict=generic,
    task_type_preset=config["task_type"],
)


# The module's hyperparameters are the user config merged with the hparams exported by the
# task/path/callback configs; with `|`, right-hand operands win on duplicate keys.
model_module = YuccaLightningModule(
    config=config | task_config.lm_hparams() | path_config.lm_hparams() | callback_config.lm_hparams(),
    deep_supervision=config["deep_supervision"],
    learning_rate=config["learning_rate"],
    loss_fn=config["loss_fn"],
    momentum=config["momentum"],
)

data_module = YuccaDataModule(
    composed_train_transforms=augmenter.train_transforms,
    composed_val_transforms=augmenter.val_transforms,
    input_dims_config=input_dims_config,
    train_data_dir=path_config.train_data_dir,
    split_idx=task_config.split_idx,
    splits_config=split_config,
    task_type=config["task_type"],
)

INFO:root:YuccaLightningModule initialized with the following config: {'batch_size': 2, 'dims': '2D', 'deep_supervision': False, 'experiment': 'default', 'learning_rate': 0.001, 'loss_fn': 'DiceCE', 'model_name': 'TinyUNet', 'momentum': 0.99, 'num_classes': 3, 'num_modalities': 1, 'patch_size': (32, 32), 'plans_name': 'demo', 'plans': None, 'split_idx': 0, 'split_method': 'kfold', 'split_param': 5, 'task': 'Task001_OASIS', 'task_type': 'segmentation', 'continue_from_most_recent': True, 'manager_name': '', 'model_dimensions': '2D', 'patch_based_training': True, 'planner_name': 'demo', 'plans_path': '/Users/zcr545/Desktop/Projects/repos/yucca_data/preprocessed/Task001_OASIS/demo/demo_plans.json', 'save_dir': '/Users/zcr545/Desktop/Projects/repos/yucca_data/models/Task001_OASIS/TinyUNet__2D/__demo/default/kfold_5_fold_0', 'train_data_dir': '/Users/zcr545/Desktop/Projects/repos/yucca_data/preprocessed/Task001_OASIS/demo', 'version_dir': '/Users/zcr545/Desktop/Projects/repos/yucca_data/mode

Composing Transforms


  from .autonotebook import tqdm as notebook_tqdm
INFO:root:Using 9 workers
INFO:root:Using dataset class: <class 'yucca.data.datasets.YuccaDataset.YuccaTrainDataset'> for train/val and <class 'yucca.data.datasets.YuccaDataset.YuccaTestDataset'> for inference


In [5]:
# Short demo run: CPU only, 2 epochs of 2 train + 2 validation batches each.
# Callbacks, loggers and profiler come from `callback_config`; outputs go under the
# experiment's save directory from `path_config`.
trainer = L.Trainer(
    # hardware / numerics
    accelerator="cpu",
    precision="32",
    # run length (kept tiny so the demo finishes quickly)
    max_epochs=2,
    limit_train_batches=2,
    limit_val_batches=2,
    # monitoring
    log_every_n_steps=2,
    logger=callback_config.loggers,
    callbacks=callback_config.callbacks,
    profiler=callback_config.profiler,
    enable_progress_bar=True,
    # output location
    default_root_dir=path_config.save_dir,
)


# ckpt_path="last" asks Lightning to resume from the last checkpoint; when none exists it
# emits a warning and trains from scratch.
trainer.fit(
    datamodule=data_module,
    model=model_module,
    ckpt_path="last",
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/zcr545/miniconda3/envs/yuccaenv/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/zcr545/miniconda3/envs/yuccaenv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path="last") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.
INFO:root:Setting up data for stage: TrainerFn.FITTING
INFO:root:Training on samples: ['/Users/zcr545/Desktop/Projects/repos/yucca_data/preprocessed/Task001_OASIS/demo/OASIS_1000', '/Users/zcr545/Desktop/Projects/repos/yucca_data/preprocessed/Task001_OASIS/demo/OASIS_1001', '/Users/zcr545/Desktop/Projects/repos/yucca_data/preprocessed/Task001_OASIS/demo/OASIS_1002', '

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/zcr545/miniconda3/envs/yuccaenv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]torch.Size([2, 3, 32, 32]) torch.Size([2, 1, 32, 32])
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00, 17.28it/s]torch.Size([2, 3, 32, 32]) torch.Size([2, 1, 32, 32])
                                                                           

INFO:root:Starting training with data from: /Users/zcr545/Desktop/Projects/repos/yucca_data/preprocessed/Task001_OASIS/demo




/Users/zcr545/miniconda3/envs/yuccaenv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 0:   0%|          | 0/2 [00:00<?, ?it/s] torch.Size([2, 3, 32, 32]) torch.Size([2, 1, 32, 32])
Epoch 0:  50%|█████     | 1/2 [00:03<00:03,  0.30it/s, v_num=0]torch.Size([2, 3, 32, 32]) torch.Size([2, 1, 32, 32])
Epoch 0: 100%|██████████| 2/2 [00:03<00:00,  0.59it/s, v_num=0]torch.Size([2, 3, 32, 32]) torch.Size([2, 1, 32, 32])
torch.Size([2, 3, 32, 32]) torch.Size([2, 1, 32, 32])
Epoch 1:   0%|          | 0/2 [00:00<?, ?it/s, v_num=0]        torch.Size([2, 3, 32, 32]) torch.Size([2, 1, 32, 32])
Epoch 1:  50%|█████     | 1/2 [00:35<00:35,  0.03it/s, v_num=0]torch.Size([2, 3, 32, 32]) torch.Size([2, 1, 32, 32])
Epoch 1: 100%|██████████| 2/2 [00:35<00:00,  0.06it/s, v_num=0]torch.Size([2, 3, 32, 32]) torch.Size([2, 1, 32, 32])
torch.Size([2, 3, 32, 32]) torch.Size([2, 1, 32, 32])
Epoch 1: 100%|██████████| 2/2 [01:20<00:00,  0.02it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 2/2 [01:20<00:00,  0.02it/s, v_num=0]
