In [None]:
import argparse
import os
import json

import wandb

from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.layers import Norm
from monai.networks.nets import UNet
from monai.data import CacheDataset, DataLoader

from csnet.utils.train import train
from csnet.utils import get_data_dict, get_model_name
from csnet.transforms.default import get_default_train_transforms
from csnet.models.csnet import CSNet
from csnet.models.csnet_orig import CSNetOrig

### Specify parameters

In [None]:
# Dataset layout: <data_dir>/<train_dir or val_dir>/{img,gt}
data_dir = 'data/semantic_3D'
train_dir = 'train'
val_dir = 'val'
# Parent directory for checkpoints; each run gets its own sub-folder inside it.
model_path = 'model_test'

wandb_project = 'CSNet Test'

# True: authenticate and log to W&B online; False: run W&B in offline mode.
log_progress = True

model = 'unet'  # one of ["unet", "csnet", "csnet_orig"]
# Fail fast on a typo here, before a W&B run and an output directory are created.
supported_models = ('unet', 'csnet', 'csnet_orig')
if model.lower() not in supported_models:
    raise ValueError(f'{model} is an invalid model; must be one of {list(supported_models)}')

# Channel widths per resolution level and residual units per level.
# Used by "unet" and "csnet"; "csnet_orig" ignores them.
model_channels = (16, 32, 64, 128)
num_residual_units = 1

In [None]:
# Hyper-parameters and run metadata.
# This dict is passed to wandb.init, saved as config.json in the run directory,
# and then converted to an argparse.Namespace for train().
config = dict(
    epochs=5,
    batch_size=2,
    lr=0.0001,
    weight_decay=0.0005,
    factor=0.1,     # presumably the LR-scheduler decay factor — TODO confirm in csnet.utils.train
    patience=2,     # presumably the LR-scheduler patience (epochs) — TODO confirm
    roi_size=(32, 64, 64),
    model_path=model_path,
    metric_name='Dice Metric',
    model=model,
    model_channels=model_channels,
    num_residual_units=num_residual_units,
    # Record the dataset the run was trained on, so config.json is self-describing.
    data_dir=data_dir,
    train_dir=train_dir,
    val_dir=val_dir,
)
config

### Initialize wandb project

In [None]:
# Authenticate with W&B when logging online; otherwise run offline so nothing is uploaded.
if log_progress:
    # An API key already exported in the environment takes precedence.
    # Otherwise read the current user's key file; `~` keeps this portable across machines.
    if not os.environ.get('WANDB_API_KEY'):
        with open(os.path.expanduser('~/.wandb_api_key'), encoding='utf-8') as f:
            # strip() drops the trailing newline a key file usually ends with;
            # keeping it would corrupt the key.
            os.environ['WANDB_API_KEY'] = f.read().strip()
else:
    os.environ['WANDB_MODE'] = 'offline'

wandb.init(project=wandb_project, config=config)

# Give each run its own output sub-directory.
# get_model_name is called after wandb.init, presumably because it may use the active run's name
# when log_progress is True — TODO confirm.
config['model_path'] = os.path.join(config['model_path'], get_model_name(log_progress))
# Keep the W&B run config in sync with the directory actually used for checkpoints.
wandb.config.update({'model_path': config['model_path']}, allow_val_change=True)

os.makedirs(config['model_path'], exist_ok=True)
with open(os.path.join(config['model_path'], 'config.json'), 'w', encoding='utf-8') as f:
    json.dump(config, f, indent=4)

# Downstream cells use attribute access (config.batch_size, config.roi_size, ...).
# NOTE: this rebinding makes the cell non-re-runnable on its own; re-run the config cell first.
config = argparse.Namespace(**config)

### Set up model, loss, and metric

In [None]:
# Build the network selected by `model`.
# UNet and CSNet share the same constructor arguments, so those are defined once.
model_key = model.lower()
shared_net_kwargs = dict(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=model_channels,
    strides=(2,) * (len(model_channels) - 1),  # one stride-2 step between consecutive levels
    num_res_units=num_residual_units,
    norm=Norm.BATCH,
)
if model_key == 'unet':
    net = UNet(**shared_net_kwargs)
elif model_key == 'csnet':
    net = CSNet(**shared_net_kwargs)
elif model_key == 'csnet_orig':
    # Positional (2, 1): presumably (out classes, in channels) — TODO confirm in csnet.models.csnet_orig.
    net = CSNetOrig(2, 1)
else:
    raise NotImplementedError(f'{model} is an invalid model; must be one of ["unet", "csnet", "csnet_orig"]')


# to_onehot_y converts the integer label map to one-hot to match the 2-channel output;
# softmax is applied to the network logits before computing Dice.
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
# Validation Dice excludes the background channel and averages over the remaining ones.
dice_metric = DiceMetric(include_background=False, reduction="mean")

### Set up data loaders

In [None]:
# Default training / validation transform pipelines, parameterised by the ROI size from config.
(train_tr,
 val_tr) = get_default_train_transforms(roi_size=config.roi_size)

In [None]:
# File lists for each split: images live in <split>/img and labels in <split>/gt.
# The generator yields the train split first, then the validation split.
train_files, val_files = (
    get_data_dict(
        os.path.join(data_dir, split_name, 'img'),
        os.path.join(data_dir, split_name, 'gt'),
    )
    for split_name in (train_dir, val_dir)
)

In [None]:
# Worker count shared by dataset caching and both loaders (scaled with the batch size).
n_workers = 2 * config.batch_size

# cache_rate=1.0 caches the whole (transformed) dataset in memory up front.
tr_ds = CacheDataset(data=train_files, transform=train_tr, cache_rate=1.0, num_workers=n_workers)
train_dl = DataLoader(tr_ds, batch_size=config.batch_size, shuffle=True, num_workers=n_workers)

val_ds = CacheDataset(data=val_files, transform=val_tr, cache_rate=1.0, num_workers=n_workers)
val_dl = DataLoader(val_ds, batch_size=config.batch_size, num_workers=n_workers)

### Train

In [None]:
# Run the training loop from csnet.utils.train.
# `config` (an argparse.Namespace at this point) supplies epochs, lr, model_path, etc.
# log_tensorboard=True presumably also writes TensorBoard logs — TODO confirm in train().
train(train_dl, val_dl, net, loss_function, dice_metric, config, log_tensorboard=True)

In [None]:
# Mark the W&B run started by wandb.init as finished (and let it finish uploading its data).
wandb.finish()