In [7]:
import os
import sys
from tqdm import tqdm
from tensorboardX import SummaryWriter
import shutil
import argparse
import logging
import time
import random
import numpy as np

import torch
import torch.optim as optim
from torchvision import transforms
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision.utils import make_grid

In [8]:
from networks.vnet import VNet
# from utils.losses import dice_loss
from utils import ramps, losses
from dataloaders.la_heart_sitk import LAHeart, RandomScale, RandomNoise, RandomCrop, CenterCrop, RandomRot, RandomFlip, ToTensor, TwoStreamBatchSampler
# -

In [9]:
parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str, default='../../data/gz_dataset/segmented', help='Name of Experiment')
parser.add_argument('--exp', type=str,  default='vnet_supervisedonly_dp', help='model_name')
parser.add_argument('--max_iterations', type=int,  default=6000, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int, default=2, help='batch_size per gpu')
parser.add_argument('--base_lr', type=float,  default=0.01, help='maximum epoch number to train')
parser.add_argument('--deterministic', type=int,  default=1, help='whether use deterministic training')
parser.add_argument('--seed', type=int,  default=1337, help='random seed')
parser.add_argument('--gpu', type=str,  default='0', help='GPU to use')
args = parser.parse_args(args=[])

In [10]:
train_data_path = args.root_path
snapshot_path = "../model/" + args.exp + "/"

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
batch_size = args.batch_size * len(args.gpu.split(','))
max_iterations = args.max_iterations
base_lr = args.base_lr

In [11]:
if args.deterministic:
    cudnn.benchmark = False
    cudnn.deterministic = True
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

num_classes = 4
patch_size = (128, 128, 64)
cls_weights = [1,4,10]

In [12]:
if not os.path.exists(snapshot_path):
    os.makedirs(snapshot_path)
if os.path.exists(snapshot_path + '/code'):
    shutil.rmtree(snapshot_path + '/code')
shutil.copytree('.', snapshot_path + '/code', shutil.ignore_patterns(['.git','__pycache__']))

logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO,
                    format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
logging.info(str(args))

net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True)
print("os.environ: ",os.environ['CUDA_VISIBLE_DEVICES'])
net = net.cuda()

db_train = LAHeart(base_dir=train_data_path,
                    split='train',
                    #num=16,
                    transform = transforms.Compose([
                        #RandomScale(ratio_low=0.6, ratio_high=1.5),
                        RandomNoise(mu=0, sigma=0.05),
                        RandomRot(),
                        RandomFlip(),
                        RandomCrop(patch_size),
                        ToTensor(),
                        ]))
db_test = LAHeart(base_dir=train_data_path,
                    split='test',
                    transform = transforms.Compose([
                        CenterCrop(patch_size),
                        ToTensor()
                    ]))
def worker_init_fn(worker_id):
    random.seed(args.seed+worker_id)
trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True,  num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)

net.train()
optimizer = optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)

writer = SummaryWriter(snapshot_path+'/log')
logging.info("{} itertations per epoch".format(len(trainloader)))

iter_num = 0
max_epoch = max_iterations//len(trainloader)+1
lr_ = base_lr
net.train()
for epoch_num in tqdm(range(max_epoch), ncols=70):
    time1 = time.time()
    for i_batch, sampled_batch in enumerate(trainloader):
        time2 = time.time()
        # print('fetch data cost {}'.format(time2-time1))
        volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
        volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
        outputs = net(volume_batch)

        loss_seg = F.cross_entropy( outputs, label_batch, weight=torch.tensor(cls_weights,dtype=torch.float32).cuda() )
        outputs_soft = F.softmax(outputs, dim=1)
        loss_seg_dice = 0
        print('\n')
        for i in range(num_classes):
            loss_mid = losses.dice_loss(outputs_soft[:, i, :, :, :], label_batch == i )
            loss_seg_dice += loss_mid
            print('dice score (1-dice_loss): {:.3f}'.format(1-loss_mid))
        print('dicetotal:{:.3f}'.format( loss_seg_dice))
        loss = 0.5*(loss_seg+loss_seg_dice)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        iter_num = iter_num + 1
        writer.add_scalar('lr', lr_, iter_num)
        writer.add_scalar('loss/loss_seg', loss_seg, iter_num)
        writer.add_scalar('loss/loss_seg_dice', loss_seg_dice, iter_num)
        writer.add_scalar('loss/loss', loss, iter_num)
        logging.info('iteration %d : loss : %f, loss_seg : %f, loss_seg_dice : %f' % 
                        (iter_num, 
                        loss.item(),
                        loss_seg.item(),
                        loss_seg_dice.item())
                    )
        if iter_num % 50 == 0:
            image = volume_batch[0, 0:1, :, :, 20:61:10].permute(3,0,1,2).repeat(1,3,1,1)
            grid_image = make_grid(image, 5, normalize=True)
            writer.add_image('train/Image', grid_image, iter_num)

            outputs_soft = F.softmax(outputs, 1)
            image = outputs_soft[0, 1:2, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
            grid_image = make_grid(image, 5, normalize=False)
            writer.add_image('train/Predicted_label', grid_image, iter_num)

            image = label_batch[0, :, :, 20:61:10].unsqueeze(0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
            grid_image = make_grid(image, 5, normalize=False)
            writer.add_image('train/Groundtruth_label', grid_image, iter_num)

        ## change lr
        if iter_num % 2500 == 0:
            lr_ = base_lr * 0.1 ** (iter_num // 2500)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_
        if iter_num % 1000 == 0:
            save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth')
            torch.save(net.state_dict(), save_mode_path)
            logging.info("save model to {}".format(save_mode_path))

        if iter_num > max_iterations:
            break
        time1 = time.time()
    if iter_num > max_iterations:
        break
save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations+1)+'.pth')
torch.save(net.state_dict(), save_mode_path)
logging.info("save model to {}".format(save_mode_path))
writer.close()

Namespace(base_lr=0.01, batch_size=2, deterministic=1, exp='vnet_supervisedonly_dp', gpu='0', max_iterations=6000, root_path='../../data/gz_dataset/segmented', seed=1337)
Namespace(base_lr=0.01, batch_size=2, deterministic=1, exp='vnet_supervisedonly_dp', gpu='0', max_iterations=6000, root_path='../../data/gz_dataset/segmented', seed=1337)
os.environ:  0
total 32 samples
total 8 samples
16 itertations per epoch
16 itertations per epoch



  0%|                                         | 0/376 [00:00<?, ?it/s][A

RuntimeError: weight tensor should be defined either for all or no classes at /pytorch/aten/src/THCUNN/generic/SpatialClassNLLCriterion.cu:27