In [1]:
import argparse
import json
import os
import warnings

import wandb
from monai.data import CacheDataset, DataLoader
from monai.losses import DiceLoss
from monai.metrics import DiceMetric

from csnet.transforms.default import get_default_train_transforms
from csnet.utils import get_data_dict, get_model_name
from csnet.utils.model import get_model
from csnet.utils.train import train

  from .autonotebook import tqdm as notebook_tqdm


### Specify parameters

In [2]:
# Dataset layout: <data_dir>/<train_dir | val_dir>/<img | gt>
data_dir, train_dir, val_dir = 'data/semantic_3D', 'train', 'val'
# Base output directory; a per-run sub-directory is appended to it later
model_path = 'model'

# Architecture to build -- one of "unet", "csnet", "csnet_orig"
model = 'csnet'

# Network size: channel counts (passed on as `n_channels`) and residual units
model_channels = (16, 32, 64, 128)
num_residual_units = 1

# Weights & Biases options (the API key is read from wandb_key_filename when given)
log_progress = False
wandb_project = 'CSNet Test'
wandb_key_filename = None

In [3]:
# All run settings in one place; this dict is logged to wandb and written to config.json.
config = {
    # Optimisation schedule
    'epochs': 30,
    'batch_size': 2,
    'lr': 0.0001,
    'weight_decay': 0.0005,
    # presumably LR-scheduler reduction factor / patience -- confirm in csnet.utils.train
    'factor': 0.1,
    'patience': 2,
    # Patch size handed to get_default_train_transforms
    'roi_size': (32, 64, 64),
    # Output / evaluation
    'model_path': model_path,
    'metric_name': 'Dice Metric',
    # Architecture
    'model': model,
    'n_channels': model_channels,
    'num_res_units': num_residual_units,
    # Logging
    'log_progress': log_progress,
    'wandb_project': wandb_project,
    # Data locations
    'data_dir': data_dir,
    'train_dirname': train_dir,
    'val_dirname': val_dir,
    'img_dirname': 'img',
    'gt_dirname': 'gt',
}
config

{'epochs': 30,
 'batch_size': 2,
 'lr': 0.0001,
 'weight_decay': 0.0005,
 'factor': 0.1,
 'patience': 2,
 'roi_size': (32, 64, 64),
 'model_path': 'model',
 'metric_name': 'Dice Metric',
 'model': 'csnet',
 'n_channels': (16, 32, 64, 128),
 'num_res_units': 1,
 'log_progress': False,
 'wandb_project': 'CSNet Test',
 'data_dir': 'data/semantic_3D',
 'train_dirname': 'train',
 'val_dirname': 'val',
 'img_dirname': 'img',
 'gt_dirname': 'gt'}

### Initialize the wandb run and save the training configuration

In [4]:
# Convert the settings dict into an attribute namespace (config.lr, config.batch_size, ...).
# NOTE: this cell is not idempotent -- re-run the config cell above before re-running it,
# otherwise `**config` fails on a Namespace and model_path would be nested a second time.
config = argparse.Namespace(**config)

# wandb syncs online only when progress logging is enabled AND an API-key file is given;
# in every other case it runs in offline mode.
use_online_wandb = config.log_progress and wandb_key_filename is not None
if use_online_wandb:
    with open(wandb_key_filename) as f:
        # Strip the trailing newline/whitespace key files usually end with;
        # otherwise it would become part of the API key.
        os.environ['WANDB_API_KEY'] = f.read().strip()
    # Set explicitly: os.environ persists across cells, so an 'offline' value left over
    # from an earlier run in this kernel would otherwise silently disable syncing.
    os.environ['WANDB_MODE'] = 'online'
else:
    if config.log_progress:
        warnings.warn('log_progress=True but wandb_key_filename is None; '
                      'wandb will run in offline mode.')
    os.environ['WANDB_MODE'] = 'offline'

wandb.init(project=config.wandb_project, config=vars(config))

# Store this run's outputs in a dedicated sub-directory of the base model path
config.model_path = os.path.join(config.model_path, get_model_name(config.log_progress))

# Save training parameters next to the checkpoints for reproducibility
os.makedirs(config.model_path, exist_ok=True)
with open(os.path.join(config.model_path, 'config.json'), 'w') as f:
    json.dump(vars(config), f, indent=4)


2023-02-15 13:47:29,781 - Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


### Setup model, loss and metric

In [5]:
# Network selected by config.model, sized by config.n_channels / config.num_res_units
net = get_model(config)

# Dice loss: the label is converted to one-hot and a softmax is applied to the network output
loss_function = DiceLoss(
    to_onehot_y=True,
    softmax=True,
)

# Validation metric: Dice averaged with reduction="mean", background channel excluded
dice_metric = DiceMetric(
    include_background=False,
    reduction="mean",
)

### Setup data loaders

In [6]:
# Train and validation transforms
train_tr, val_tr = get_default_train_transforms(roi_size=config.roi_size)

# Training and validation file lists
train_files = get_data_dict(os.path.join(config.data_dir, config.train_dirname, config.img_dirname),
                            os.path.join(config.data_dir, config.train_dirname, config.gt_dirname))
val_files = get_data_dict(os.path.join(config.data_dir, config.val_dirname, config.img_dirname),
                          os.path.join(config.data_dir, config.val_dirname, config.gt_dirname))

# Dataset and dataloader for training
tr_ds = CacheDataset(data=train_files, transform=train_tr, cache_rate=1, num_workers=2 * config.batch_size)
train_dl = DataLoader(tr_ds, batch_size=config.batch_size, shuffle=True, num_workers=2 * config.batch_size)

# Dataset and dataloader for validation
val_ds = CacheDataset(data=val_files, transform=val_tr, cache_rate=1.0, num_workers=2 * config.batch_size)
val_dl = DataLoader(val_ds, batch_size=config.batch_size, num_workers=2 * config.batch_size)

Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 221335.30it/s]
Loading dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 186082.70it/s]


### Train

In [7]:
# Run the training loop. Per the output below, the best checkpoint is saved under
# config.model_path and metrics also appear in the wandb run; TensorBoard logging is on.
train(
    train_dl,
    val_dl,
    net,
    loss_function,
    dice_metric,
    config,
    log_tensorboard=True,
)

epoch 1 training loss: 0.6338
epoch 1 validation loss: 0.6181; Dice Metric: 0.0045
Saved new best model to: model/2023-02-15-13-47-37/best_model.pth
epoch 2 training loss: 0.5984
epoch 2 validation loss: 0.5876; Dice Metric: 0.0074
Saved new best model to: model/2023-02-15-13-47-37/best_model.pth
epoch 3 training loss: 0.5688
epoch 3 validation loss: 0.5589; Dice Metric: 0.0640
Saved new best model to: model/2023-02-15-13-47-37/best_model.pth
epoch 4 training loss: 0.5413
epoch 4 validation loss: 0.5359; Dice Metric: 0.1723
Saved new best model to: model/2023-02-15-13-47-37/best_model.pth
epoch 5 training loss: 0.5228
epoch 5 validation loss: 0.5205; Dice Metric: 0.1628
epoch 6 training loss: 0.5092
epoch 6 validation loss: 0.5074; Dice Metric: 0.1899
Saved new best model to: model/2023-02-15-13-47-37/best_model.pth
epoch 7 training loss: 0.4961
epoch 7 validation loss: 0.4915; Dice Metric: 0.2364
Saved new best model to: model/2023-02-15-13-47-37/best_model.pth
epoch 8 training loss: 

In [8]:
# Finish the wandb run started by wandb.init earlier; this prints the run history/summary shown below.
wandb.finish()

0,1
Dice Metric,▁▁▂▃▃▃▃▄▄▅▆▆▇▇▇▇▇▇█▇██████████
average training loss,█▇▇▆▆▆▆▅▅▅▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
training loss,██▇▇▆▆▆▆▆▅▅▅▅▄▄▄▃▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▂
validation loss,██▇▇▆▆▆▆▅▅▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
Dice Metric,0.66441
average training loss,0.24303
epoch,30.0
lr,0.0001
training loss,0.24068
validation loss,0.18155
