In [1]:
import argparse
import json
import os

import numpy as np
import wandb
from monai.data import CacheDataset, DataLoader
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.layers import Norm
from monai.networks.nets import UNet

from csnet.models.csnet import CSNet
from csnet.models.csnet_orig import CSNetOrig
from csnet.transforms.default import get_default_train_transforms
from csnet.utils import get_data_dict, get_model_name
from csnet.utils.train import train

### Specify CSNet parameters

In [2]:
# --- Data layout -----------------------------------------------------------
# Dataset root and the names of its training / validation sub-directories.
data_dir = "data/semantic_3D"
train_dir, val_dir = "train", "val"

# --- Outputs and experiment tracking ---------------------------------------
# Base directory for saved run artefacts (a per-run sub-directory is appended later).
model_path = "model_test"
wandb_project = "CSNet Test"
# True: log the run to wandb online; False: wandb is switched to offline mode.
log_progress = True

# --- Architecture ----------------------------------------------------------
# Must be one of ["unet", "csnet", "csnet_orig"].
model = "unet"
model_channels = (16, 32, 64, 128)
num_residual_units = 1

In [3]:
# Collect every run setting in a single mapping so it can be handed to wandb
# and written to disk alongside the trained model.
config = {
    # Optimisation settings
    'epochs': 5,
    'batch_size': 2,
    'lr': 0.0001,
    'weight_decay': 0.0005,
    # NOTE(review): `factor` / `patience` look like LR-scheduler settings
    # consumed inside train() -- confirm against csnet.utils.train.
    'factor': 0.1,
    'patience': 2,
    'roi_size': (32, 64, 64),
    # Values forwarded from the parameter cell
    'model_path': model_path,
    'metric_name': 'Dice Metric',
    'model': model,
    'n_channels': model_channels,
    'num_res_units': num_residual_units,
    'log_progress': log_progress,
    'wandb_project': wandb_project,
    # Dataset layout: <data_dir>/<train|val dirname>/<img|gt dirname>
    'data_dir': data_dir,
    'train_dirname': train_dir,
    'val_dirname': val_dir,
    'img_dirname': 'img',
    'gt_dirname': 'gt',
}
config

{'epochs': 5,
 'batch_size': 2,
 'lr': 0.0001,
 'weight_decay': 0.0005,
 'factor': 0.1,
 'patience': 2,
 'roi_size': (32, 64, 64),
 'model_path': 'model_test',
 'metric_name': 'Dice Metric',
 'model': 'unet',
 'n_channels': (16, 32, 64, 128),
 'num_res_units': 1,
 'log_progress': True,
 'wandb_project': 'CSNet Test',
 'data_dir': 'data/semantic_3D',
 'train_dirname': 'train',
 'val_dirname': 'val',
 'img_dirname': 'img',
 'gt_dirname': 'gt'}

### Initialize wandb project

In [4]:
# Switch from the plain dict to attribute access (config.lr, config.model, ...)
# for the remainder of the notebook.
config = argparse.Namespace(**config)
if config.log_progress:
    # Resolve the key file relative to the current user's home directory instead
    # of a hard-coded absolute path, so the notebook runs on other accounts too.
    key_file = os.path.expanduser('~/.wandb_api_key')
    if os.path.isfile(key_file):
        with open(key_file) as f:
            # Strip surrounding whitespace: a trailing newline is not part of the key.
            os.environ['WANDB_API_KEY'] = f.read().strip()
    # Without a key file the environment is left untouched, so wandb can still
    # pick up an already exported WANDB_API_KEY or credentials from `wandb login`.
else:
    # Logging disabled: keep wandb from syncing to the server.
    os.environ['WANDB_MODE'] = 'offline'

# The run is started with the settings as they are at this point, i.e. before
# model_path is extended with the run-specific sub-directory below.
wandb.init(project=config.wandb_project, config=vars(config))

# Update model path: store this run's artefacts in their own sub-directory
config.model_path = os.path.join(config.model_path, get_model_name(config.log_progress))

# Save training parameters next to the model so the run can be reproduced
os.makedirs(config.model_path, exist_ok=True)
with open(os.path.join(config.model_path, 'config.json'), 'w') as f:
    json.dump(vars(config), f, indent=4)
          


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mamedyukh[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.21 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


### Setup model, loss and metric

In [5]:
# Build the network selected by config.model (case-insensitive).
model_name = config.model.lower()

# UNet and CSNet are constructed from the same settings, so they are gathered
# once. The 'csnet' branch previously read config.model_channels, which is not
# a key of the config (the channels are stored as n_channels) and raised
# AttributeError as soon as that model was selected.
unet_like_kwargs = dict(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=config.n_channels,
    # One stride of 2 per transition between consecutive channel levels
    strides=(2,) * (len(config.n_channels) - 1),
    num_res_units=config.num_res_units,
    norm=Norm.BATCH,
)

if model_name == 'unet':
    net = UNet(**unet_like_kwargs)
elif model_name == 'csnet':
    net = CSNet(**unet_like_kwargs)
elif model_name == 'csnet_orig':
    # NOTE(review): positional arguments are presumably (classes, channels) -- confirm
    # against csnet.models.csnet_orig.CSNetOrig.
    net = CSNetOrig(2, 1)
else:
    raise NotImplementedError(
        rf'{config.model} is an invalid model; must be one of ["unet", "csnet", "csnet_orig"]')

# Dice loss with to_onehot_y / softmax enabled; the Dice metric is configured
# to leave out the background class and average the result.
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
dice_metric = DiceMetric(include_background=False, reduction="mean")

### Setup data loaders

In [6]:
# Transform pipelines for the training and validation splits
train_tr, val_tr = get_default_train_transforms(roi_size=config.roi_size)

# Each split lives in <data_dir>/<split>/ with separate image and ground-truth folders
train_root = os.path.join(config.data_dir, config.train_dirname)
val_root = os.path.join(config.data_dir, config.val_dirname)

train_files = get_data_dict(
    os.path.join(train_root, config.img_dirname),
    os.path.join(train_root, config.gt_dirname),
)
val_files = get_data_dict(
    os.path.join(val_root, config.img_dirname),
    os.path.join(val_root, config.gt_dirname),
)

# The same worker count (twice the batch size) is used for caching and for loading
n_workers = 2 * config.batch_size

# Training data: fully cached and reshuffled by the loader
tr_ds = CacheDataset(data=train_files, transform=train_tr, cache_rate=1, num_workers=n_workers)
train_dl = DataLoader(tr_ds, batch_size=config.batch_size, shuffle=True, num_workers=n_workers)

# Validation data: fully cached, loaded without shuffling
val_ds = CacheDataset(data=val_files, transform=val_tr, cache_rate=1.0, num_workers=n_workers)
val_dl = DataLoader(val_ds, batch_size=config.batch_size, num_workers=n_workers)

Loading dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [00:00<00:00, 750322.72it/s]
Loading dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [00:00<00:00, 623533.80it/s]


### Train

In [None]:
# Run training with the loaders, network, loss, metric and config built in the
# cells above; all of the work happens inside csnet.utils.train.train.
# NOTE(review): log_tensorboard=True presumably enables TensorBoard logging
# in addition to wandb -- confirm in csnet.utils.train.
train(train_dl, val_dl, net, loss_function, dice_metric, config, log_tensorboard=True)

epoch 1 training loss: 0.5894
epoch 1 validation loss: 0.5674; Dice Metric: 0.0023
Saved new best model to: model_test/dry-bush-38/best_model.pth


In [None]:
# Close the wandb run opened by wandb.init() earlier in the notebook.
wandb.finish()