In [None]:
import os
import datetime

import torch
from torch import optim
from torch.utils.data import DataLoader

from csnet.dataset import CS_Dataset
from csnet.losses import WeightedCrossEntropyLoss, DiceLoss
from csnet.model import CSNet3D

In [None]:
# Preview of the current timestamp in the same "%Y-%m-%d-%H-%M-%S" format that
# is used later to name the run's checkpoint directory (`model_name`).
f"{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}"

In [None]:
# Seed torch's RNGs (CPU and all CUDA devices) before any model or DataLoader is
# built, so weight initialisation and shuffling repeat across Restart & Run All.
# Some GPU kernels may still be nondeterministic, so runs are not guaranteed to
# be bit-identical.
seed = 42
torch.manual_seed(seed)

# Dataset locations, relative to the notebook's working directory.
# NOTE(review): the dir_test_* paths are not used by the training cells; confirm
# whether an evaluation cell elsewhere relies on them.
dir_train_gt = 'data/train/gt'
dir_train_img = 'data/train/img'
dir_test_gt = 'data/test/gt'
dir_test_img = 'data/test/img'

# Optimisation hyper-parameters (Adam learning rate and weight decay).
batch_size = 2
lr = 0.0001
weight_decay = 0.0005
# Weights of the cross-entropy and Dice terms in the combined training loss.
wce_loss_weight = 0.6
dice_loss_weight = 0.4

# Checkpoint every `snapshot` epochs into model_path/<model_name>/.
snapshot = 2
model_path = 'model'

In [None]:
# --- data: training split, reshuffled every epoch ---------------------------
ds = CS_Dataset(dir_train_img, dir_train_gt)
dl = DataLoader(
    ds,
    batch_size=batch_size,
    shuffle=True,
    # NOTE(review): the worker count reuses batch_size; the two are unrelated,
    # so consider giving it its own setting.
    num_workers=batch_size,
)

# --- model and optimiser (model placed on the GPU via .cuda()) ---------------
net = CSNet3D(channels=1, classes=2).cuda()
optimizer = optim.Adam(params=net.parameters(), lr=lr, weight_decay=weight_decay)

# --- loss terms, combined with wce_loss_weight / dice_loss_weight in training
wce_loss = WeightedCrossEntropyLoss().cuda()
dice_loss = DiceLoss().cuda()

# --- run directory: one timestamped folder per run under model_path ----------
model_name = f"{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}"
os.makedirs(os.path.join(model_path, model_name))

In [None]:
def save_ckpt(net, epoch, model_path, model_name):
    """Save ``net`` to ``<model_path>/<model_name>/<model_name>_<epoch>.pkl``.

    The whole object is passed to ``torch.save`` (not ``net.state_dict()``).
    Loading the file back therefore needs the model's class to be importable.

    Args:
        net: Object to serialise, typically the model being trained.
        epoch: Epoch number embedded in the file name.
        model_path: Root directory for all runs' checkpoints.
        model_name: Run name, used both as the sub-directory and as the
            file-name prefix.

    Returns:
        str: Path of the file that was written.
    """
    run_dir = os.path.join(model_path, model_name)
    # Create the run directory on demand so saving does not depend on an
    # earlier cell having created it.
    os.makedirs(run_dir, exist_ok=True)
    fn_out = os.path.join(run_dir, f"{model_name}_{epoch}.pkl")
    torch.save(net, fn_out)
    print(f"Saved model to: {fn_out}")
    return fn_out

In [None]:
# Train for a fixed number of epochs. Checkpoint every `snapshot` epochs and
# always after the final epoch, so the last weights are never lost.
num_epochs = 10
for epoch in range(num_epochs):
    net.train()
    # Accumulate the loss as a detached tensor to avoid a host sync every step.
    running_loss = 0.0
    n_batches = 0
    for batch in dl:
        image = batch[0].cuda()
        label = batch[1].cuda()
        optimizer.zero_grad()
        pred = net(image)
        # Weighted sum of the two loss terms. The cross-entropy term receives
        # the label with dim 1 squeezed out; the Dice term receives it as is.
        loss = (wce_loss_weight * wce_loss(pred, label.squeeze(1))
                + dice_loss_weight * dice_loss(pred, label))
        loss.backward()
        optimizer.step()
        running_loss += loss.detach()
        n_batches += 1

    # Guard against an empty DataLoader instead of dividing by zero.
    mean_loss = float(running_loss) / n_batches if n_batches else float("nan")
    print(f"epoch {epoch + 1}/{num_epochs}: mean loss {mean_loss:.4f}")

    if (epoch + 1) % snapshot == 0 or epoch + 1 == num_epochs:
        save_ckpt(net, epoch + 1, model_path, model_name)