In [None]:
import os
import numpy as np
import datetime

import torch
from torch.utils.data import DataLoader

from csnet.dataset import CS_Dataset
from csnet.losses import WeightedCrossEntropyLoss, DiceLoss
from csnet.model import CSNet3D
from csnet.train import model_eval
from csnet.utils import save_model

In [None]:
# --- Dataset locations (relative paths) ---
dir_train_img = 'data/train/img'
dir_train_gt = 'data/train/gt'
dir_test_img = 'data/test/img'
dir_test_gt = 'data/test/gt'

# --- Optimisation settings ---
batch_size = 2
lr = 1e-4
weight_decay = 5e-4

# Weights of the two terms that are summed into the training loss.
wce_loss_weight, dice_loss_weight = 0.6, 0.4

# --- StepLR schedule (the scheduler is stepped once per epoch) ---
step_size = 1
gamma = 0.1

# --- Evaluation / checkpoint cadence, counted in epochs ---
test_step = 2
snapshot = 2

# Root directory under which a per-run sub-directory is created.
model_path = 'model'

In [None]:
# Prefer the GPU, but fall back to the CPU so this cell does not crash on a
# machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training data is reshuffled every epoch; evaluation data keeps a fixed order.
# NOTE(review): num_workers is tied to batch_size here -- these are unrelated
# quantities; consider a dedicated setting in the config cell.
ds = CS_Dataset(dir_train_img, dir_train_gt)
dl = DataLoader(ds, batch_size=batch_size, num_workers=batch_size, shuffle=True)

ds_test = CS_Dataset(dir_test_img, dir_test_gt)
dl_test = DataLoader(ds_test, batch_size=batch_size, num_workers=batch_size, shuffle=False)

# Two output classes, single-channel input.
net = CSNet3D(classes=2, channels=1).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
# Multiplies the learning rate by `gamma` every `step_size` scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                               step_size=step_size, gamma=gamma)

wce_loss = WeightedCrossEntropyLoss().to(device)
dice_loss = DiceLoss().to(device)

# Each run gets its own timestamped directory under `model_path`.
# exist_ok=True keeps a quick re-run of this cell (same second) from raising.
model_name = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
os.makedirs(os.path.join(model_path, model_name), exist_ok=True)

In [None]:
# Number of full passes over the training loader.
num_epochs = 10

# Send every batch to wherever the model's weights already live, instead of
# hard-coding a particular device in the loop.
device = next(net.parameters()).device

for epoch in range(num_epochs):
    net.train()
    tr_loss = []
    for batch in dl:
        # Each batch is indexed as (image, label).
        image = batch[0].to(device)
        label = batch[1].to(device)

        optimizer.zero_grad()
        pred = net(image)
        # Combined objective: weighted cross-entropy plus Dice.
        # The cross-entropy term receives the label with dim 1 squeezed out;
        # the Dice term receives the label unchanged.
        loss = (wce_loss_weight * wce_loss(pred, label.squeeze(1))
                + dice_loss_weight * dice_loss(pred, label))
        loss.backward()
        optimizer.step()
        tr_loss.append(loss.item())

    # One scheduler step per epoch, after that epoch's optimizer updates.
    lr_scheduler.step()

    # An empty loader yields NaN without numpy's "Mean of empty slice" warning.
    mean_tr_loss = np.mean(tr_loss) if tr_loss else float('nan')
    print(f"Epoch {epoch + 1}, training loss: {mean_tr_loss}")

    # Periodic evaluation on the held-out loader.
    if (epoch + 1) % test_step == 0:
        val_loss, recall, precision, iou = model_eval(net, dl_test,
                                                      wce_loss_weight, dice_loss_weight,
                                                      wce_loss, dice_loss)
        print(f"Epoch {epoch + 1}, val loss: {val_loss};"
              f" recall: {recall}; precision: {precision}; IOU: {iou}")

    # Periodic snapshot, tagged with the 1-based epoch number.
    if (epoch + 1) % snapshot == 0:
        save_model(net, epoch + 1, model_path, model_name)
        