In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
from dataset import *
from model import *
from loss import *
import os
import SimpleITK as sitk
%matplotlib widget

In [2]:
# Execution target for the notebook: 'gpu' requests CUDA, any other value
# selects the CPU.
mode = 'gpu'

In [3]:
# Select the compute device and make float64 the default dtype for new tensors.
# After switching the device you need to restart the kernel.
use_cuda = mode == 'gpu' and torch.cuda.is_available()
if use_cuda:
    device = torch.device('cuda')
    # torch.cuda.set_device(1)  # uncomment to pin a specific GPU
    # From here on, tensors created without an explicit device/dtype are
    # float64 CUDA tensors.
    torch.set_default_tensor_type('torch.cuda.DoubleTensor')
else:
    # Also reached when mode == 'gpu' but CUDA is unavailable: requesting the
    # CUDA default tensor type in that case would fail, so fall back to
    # float64 on the CPU instead.
    device = torch.device('cpu')
    torch.set_default_dtype(torch.float64)

1. For classification (segmentation is voxel-wise classification), applying `F.softmax(output, dim=1)` at the end of the model is essential: it constrains the output to a probability distribution. Without it you may get negative values with no obvious source. (Exception: `BCEWithLogitsLoss` applies a sigmoid internally and expects raw logits, so do not apply softmax before it.)
2. The numerator of the Dice loss for each category closely resembles the cross entropy: it is the inner product of a softmax vector with a one-hot vector, so only the value at the position of the one matters.
3. For segmentation, use the Dice loss.

## Training
### initialization

In [4]:
# Run configuration.
resume = True        # load model weights from a saved checkpoint before training
save_model = True    # write checkpoints while training
print(f'resume:{resume}, save_model:{save_model}')

# Directory that receives the checkpoints, the loss log and the loss plot.
output_dir = 'Models/Unet'
# makedirs also creates the missing parent directory ('Models'), which
# os.mkdir would not, and exist_ok makes re-running this cell harmless.
os.makedirs(output_dir, exist_ok=True)

resume:True, save_model:True


In [5]:
# --- Hyper-parameters --------------------------------------------------------
epoch_loss_list = []     # summed loss of each epoch run in this session
epoch_num = 1001         # number of epochs to run in this session
start_epoch_num = 5      # first epoch index when resuming (reset to 0 otherwise)
batch_size = 1
learning_rate = 15

# --- Model, loss and optimizer -----------------------------------------------
model = UNet64()
model.train()
# `device` already falls back to the CPU when CUDA is unavailable, so moving
# the model to it is safe in both modes (an unconditional .cuda() is not).
model.to(device)

# Replicate the model over the visible GPUs (at most four, as before) instead
# of hard-coding ids 0-3, which fails on machines with fewer GPUs. Without a
# GPU the bare model is used directly.
gpu_ids = list(range(min(4, torch.cuda.device_count()))) if device.type == 'cuda' else []
net = torch.nn.DataParallel(model, device_ids=gpu_ids) if gpu_ids else model

# criterion = DiceLoss()
# BCEWithLogitsLoss applies the sigmoid itself, so it must be fed raw logits.
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adadelta(model.parameters(), lr=learning_rate)

# --- Data --------------------------------------------------------------------
dataset = UnetDataset(root_dir='/home/sci/hdai/Projects/Dataset/LymphNodes')
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

# --- Optional resume ---------------------------------------------------------
if resume:
    checkpoint_path = f'{output_dir}/epoch_{start_epoch_num-1}_checkpoint.pth.tar'
    # SECURITY: torch.load unpickles the file - only load checkpoints you trust.
    # map_location lets a checkpoint written on a GPU be restored on this device.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # The optimizer state stored in the checkpoint is not restored, so the
    # optimizer starts fresh with the learning rate set above. Re-enable the
    # next line to continue with the saved optimizer state instead.
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    log_mode = 'a'      # keep the existing loss history
else:
    start_epoch_num = 0
    log_mode = 'w+'     # start a new loss log

# Header of this training session in the loss log. The loss and optimizer
# names are taken from the objects themselves so the header cannot drift away
# from the configuration (it previously said "Dice" while BCE was in use).
with open(f'{output_dir}/loss.txt', log_mode) as f:
    f.write(f'From {start_epoch_num} to {epoch_num+start_epoch_num}\n')
    f.write(f'{type(criterion).__name__}; {type(optimizer).__name__}, '
            f'lr={learning_rate}; batch size: {batch_size}\n')

print(f'Starting from iteration {start_epoch_num} to iteration {epoch_num+start_epoch_num}')

# params 464849, # conv layers 30
Starting from iteration 5 to iteration 1006


### process

In [None]:
# Train for `epoch_num` epochs. After every epoch the summed batch loss is
# appended to loss.txt and, if enabled, a checkpoint is written.
checkpoint_interval = 1  # write a checkpoint every N epochs

for epoch in tqdm(range(start_epoch_num, start_epoch_num+epoch_num), desc='epochs'):
    epoch_loss = 0

    # Wrap the dataloader itself rather than enumerate(...), so tqdm knows the
    # number of batches; leave=False removes each finished inner bar instead
    # of stacking one per epoch in the cell output.
    for batched_sample in tqdm(dataloader, desc=f'epoch {epoch}', leave=False):
        input_data = batched_sample['img'].double().to(device)
        # NOTE(review): gradients w.r.t. the input are never read in this
        # loop; kept because the model may rely on it - confirm before removing.
        input_data.requires_grad = True
        output_pred = net(input_data)
        # The loss needs a floating-point target on the same device as the
        # prediction; .to() is a no-op when the mask is already there.
        output_true = batched_sample['mask'].double().to(output_pred.device)

        optimizer.zero_grad()
        loss = criterion(output_pred, output_true)
        loss.backward()
        epoch_loss += loss.item()
        optimizer.step()

    with open(f'{output_dir}/loss.txt', 'a') as f:
        f.write(f'{epoch_loss}\n')

    # tqdm.write prints without corrupting the active progress bars.
    tqdm.write(f'epoch {epoch} innerdomain loss: {epoch_loss}')
    epoch_loss_list.append(epoch_loss)

    if save_model and epoch % checkpoint_interval == 0:
        # Save the underlying model (not the DataParallel wrapper) so the
        # state-dict keys match what the resume code loads into `model`.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': epoch_loss,
        }, f'{output_dir}/epoch_{epoch}_checkpoint.pth.tar')

  0%|          | 0/1001 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
1it [00:15, 15.38s/it][A
2it [00:33, 16.99s/it][A
3it [00:51, 17.55s/it][A
4it [01:09, 17.81s/it][A
5it [01:23, 16.23s/it][A
6it [01:37, 15.45s/it][A
7it [01:52, 15.50s/it][A
8it [02:09, 15.82s/it][A
9it [02:23, 15.28s/it][A
10it [02:37, 14.86s/it][A
11it [02:53, 15.37s/it][A
12it [03:07, 14.79s/it][A
13it [03:21, 14.59s/it][A
14it [03:35, 14.43s/it][A
15it [03:54, 15.77s/it][A
16it [04:12, 16.31s/it][A
17it [04:37, 18.93s/it][A
18it [04:50, 17.14s/it][A
19it [05:15, 19.69s/it][A
20it [05:30, 18.10s/it][A
21it [05:45, 17.36s/it][A
22it [06:02, 17.26s/it][A
23it [06:20, 17.40s/it][A
24it [06:32, 15.84s/it][A
25it [06:45, 15.04s/it][A
26it [06:59, 14.78s/it][A
27it [07:13, 14.31s/it][A
28it [07:25, 13.66s/it][A
29it [07:51, 17.27s/it][A
30it [08:04, 16.17s/it][A
31it [08:21, 16.42s/it][A
32it [08:39, 16.89s/it][A
33it [08:57, 17.18s/it][A
34it [09:10, 16.01s/it][A
35it [09:25, 15.57s/it][A
3

In [None]:
# Device holding the most recent prediction. `device` is an attribute of a
# tensor, not a method, so it must not be called.
output_pred.device

In [None]:
# Shapes of the last batch processed by the training loop.
# NOTE(review): this cell used input_id / output_pred_id / output_true_id,
# which are not defined anywhere in this notebook; the training-loop variables
# are used instead - confirm these are the tensors that were meant.
print(input_data.shape)
print(output_pred.shape, output_true.shape)

In [None]:
# Smallest value in the last target mask seen by the training loop.
# NOTE(review): this cell used output_true_id, which is not defined anywhere
# in this notebook; output_true is used instead - confirm it is the tensor
# that was meant.
output_true.min()

In [None]:
# Plot the per-epoch training loss recorded in this session and save the figure.
fig, ax = plt.subplots(figsize=(7, 5))
ax.set_title('Innerdomain loss')
ax.set_xlabel('epoch')
# The criterion is BCEWithLogitsLoss and each point is the sum over the
# batches of one epoch, so the axis must not be labelled as an MSE.
ax.set_ylabel('BCE-with-logits loss (summed over batches)')
# Use absolute epoch numbers on the x-axis so a resumed run is not drawn as
# if it had started at epoch 0.
epochs_run = range(start_epoch_num, start_epoch_num + len(epoch_loss_list))
ax.plot(epochs_run, epoch_loss_list)
# NOTE(review): the file name mentions "1e-1" although the configured
# learning rate differs - confirm the intended name before relying on it.
fig.savefig(f'{output_dir}/adadelta_loss_1e-1.png')