In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
from dataset import *
from model import *
from loss import *
import os
import SimpleITK as sitk
%matplotlib widget

In [2]:
mode='gpu'

In [3]:
if mode=='gpu':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # after switch device, you need restart the kernel
#     torch.cuda.set_device(1)
    torch.set_default_tensor_type('torch.cuda.DoubleTensor')
else:
    device = torch.device('cpu')
    torch.set_default_dtype(torch.float64)

1. For classifications(segmentation=voxel-wise classification), `F.softmax(output, dim=1)` is very necessary at the end of the model, as it constraints the output into a probability, or you may have negative value that you also have no clue where it comes from.
2. The numerator in dice loss for each category is very much like the cross entropy: a softmax vector inner product with a one-hot vector - only the value at where one is matters.
2. For segmentation, use dice loss.

## Training
### initialization

In [4]:
resume = True
save_model = True
print(f'resume:{resume}, save_model:{save_model}')
output_dir = 'Models/Fnet'
if not os.path.isdir(output_dir):
    os.mkdir(output_dir)

resume:True, save_model:True


In [5]:
epoch_loss_list = []
epoch_num = 101
start_epoch_num = 7
batch_size = 12
learning_rate = 5e0

model = FNet()
model.train()
if mode=='gpu':
    model.cuda()
net = torch.nn.DataParallel(model, device_ids=[0, 1])
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adadelta(model.parameters(), lr=learning_rate)

dataset = FnetDataset(root_dir='/home/sci/hdai/Projects/Dataset/LymphNodes')
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

if resume:
    checkpoint = torch.load(f'{output_dir}/epoch_{start_epoch_num-1}_checkpoint.pth.tar')    
    model.load_state_dict(checkpoint['model_state_dict'])
    
    with open(f'{output_dir}/loss.txt', 'a') as f:
        f.write(f'From {start_epoch_num} to {epoch_num+start_epoch_num}\n')
        f.write(f'Dice; Adadelta, lr={learning_rate}; batch size: {batch_size}\n')
else:
    start_epoch_num = 0  
    
    with open(f'{output_dir}/loss.txt', 'w+') as f:
        f.write(f'From {start_epoch_num} to {epoch_num+start_epoch_num}\n')
        f.write(f'Dice; Adadelta: lr={learning_rate}; batch size: {batch_size}\n')
    
print(f'Starting from iteration {start_epoch_num} to iteration {epoch_num+start_epoch_num}')

# params 176982, # conv layers 40
Starting from iteration 7 to iteration 108


### process

In [6]:
for epoch in tqdm(range(start_epoch_num, start_epoch_num+epoch_num)):
    epoch_loss = 0
            
    for i, batched_sample in tqdm(enumerate(dataloader)):
        '''innerdomain backpropagate'''
#         print(i)
        input0 = batched_sample['img0'].double()#.to(device)
        input1 = batched_sample['img1'].double()#.to(device)
        input2 = batched_sample['img2'].double()#.to(device)
        input3 = batched_sample['img3'].double()#.to(device)
#         print(input.shape)
        input0.requires_grad = True
        input1.requires_grad = True
        input2.requires_grad = True
        input3.requires_grad = True
        # u_pred: [batch_size, *data_shape, feature_num] = [1, 5, ...]
        output_pred = net(input0,input1,input2,input3)
        output_true = batched_sample['mask']#.to(device)#.double()
#         print(output_pred.shape, output_true.shape)
    
        optimizer.zero_grad()
#         loss = criterion(output_pred, output_true.squeeze())
        loss = criterion(output_pred, output_true.double())
        loss.backward()
        epoch_loss += loss.item()
        optimizer.step()
        
    with open(f'{output_dir}/loss.txt', 'a') as f:
        f.write(f'{epoch_loss}\n')
    
    print(f'epoch {epoch} innerdomain loss: {epoch_loss}')#, norm: {torch.norm(f_pred,2)**2}
    epoch_loss_list.append(epoch_loss)
    if epoch%1==0:       
        if save_model:
            torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': epoch_loss,
            }, f'{output_dir}/epoch_{epoch}_checkpoint.pth.tar')

  0%|          | 0/101 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
1it [02:53, 173.77s/it][A
2it [05:32, 164.99s/it][A
3it [08:05, 159.62s/it][A
4it [10:42, 158.58s/it][A
5it [13:18, 157.45s/it][A
6it [15:46, 154.47s/it][A
7it [18:18, 153.61s/it][A
8it [19:27, 145.98s/it][A


epoch 7 innerdomain loss: 4.076993491798993


  1%|          | 1/101 [19:28<32:27:27, 1168.47s/it]
0it [00:00, ?it/s][A
1it [01:53, 113.03s/it][A
2it [03:39, 108.96s/it][A
3it [05:32, 111.14s/it][A
4it [07:25, 111.75s/it][A
5it [09:21, 113.38s/it][A
6it [11:10, 111.79s/it][A
7it [13:14, 115.65s/it][A
8it [14:04, 105.56s/it][A


epoch 8 innerdomain loss: 4.079719769197151


  2%|▏         | 2/101 [33:33<26:54:23, 978.42s/it] 
0it [00:00, ?it/s][A
1it [01:44, 104.96s/it][A
2it [03:28, 104.15s/it][A
3it [05:18, 106.86s/it][A
4it [06:57, 103.81s/it][A
5it [08:43, 104.56s/it][A
6it [10:24, 103.39s/it][A
7it [12:10, 104.04s/it][A
8it [12:51, 96.43s/it] [A


epoch 9 innerdomain loss: 4.0776736564132205


  3%|▎         | 3/101 [46:26<24:04:29, 884.38s/it]
0it [00:00, ?it/s][A
1it [01:47, 107.10s/it][A
2it [03:38, 109.86s/it][A
3it [05:22, 106.98s/it][A
4it [07:08, 106.47s/it][A
5it [08:56, 107.15s/it][A
6it [10:32, 103.35s/it][A
7it [12:13, 102.54s/it][A
8it [12:52, 96.55s/it] [A


epoch 10 innerdomain loss: 4.075621725251256


  4%|▍         | 4/101 [59:19<22:38:57, 840.59s/it]
0it [00:00, ?it/s][A
1it [01:44, 104.09s/it][A
2it [03:29, 104.74s/it][A
3it [05:17, 106.43s/it][A
4it [06:56, 103.24s/it][A
5it [08:41, 103.91s/it][A
6it [10:28, 104.92s/it][A
7it [12:09, 103.73s/it][A
8it [12:51, 96.38s/it] [A


epoch 11 innerdomain loss: 4.079436348448593


  5%|▍         | 5/101 [1:12:12<21:45:33, 815.97s/it]
0it [00:00, ?it/s][A
1it [01:43, 103.19s/it][A
2it [03:28, 104.68s/it][A
3it [05:13, 104.40s/it][A
4it [06:56, 103.97s/it][A
5it [08:42, 104.74s/it][A
6it [10:24, 103.82s/it][A
7it [12:10, 104.64s/it][A
8it [12:47, 95.96s/it] [A


epoch 12 innerdomain loss: 4.076389387851168


  6%|▌         | 6/101 [1:25:01<21:06:40, 800.01s/it]
0it [00:00, ?it/s][A
1it [01:47, 107.17s/it][A
2it [03:29, 104.43s/it][A
3it [05:14, 104.52s/it][A
4it [06:57, 103.80s/it][A
5it [08:40, 103.82s/it][A
6it [10:19, 102.14s/it][A
7it [11:59, 101.19s/it][A
8it [12:39, 94.90s/it] [A


epoch 13 innerdomain loss: 4.082530263907854


  7%|▋         | 7/101 [1:37:41<20:32:48, 786.90s/it]
0it [00:00, ?it/s][A
1it [01:35, 95.03s/it][A
2it [03:18, 100.10s/it][A
3it [04:51, 96.92s/it] [A
4it [06:19, 93.39s/it][A
5it [08:06, 98.31s/it][A
6it [09:48, 99.50s/it][A
7it [11:27, 99.37s/it][A
8it [12:05, 90.70s/it][A


epoch 14 innerdomain loss: 4.0868927567141204


  8%|▊         | 8/101 [1:49:47<19:49:49, 767.63s/it]
0it [00:00, ?it/s][A
1it [01:38, 98.27s/it][A
2it [03:18, 99.15s/it][A
3it [05:01, 101.15s/it][A
4it [06:43, 101.49s/it][A
5it [08:23, 100.80s/it][A
6it [10:03, 100.71s/it][A
7it [11:41, 99.89s/it] [A
8it [12:17, 92.18s/it][A


epoch 15 innerdomain loss: 4.076055617858074


  9%|▉         | 9/101 [2:02:06<19:23:10, 758.59s/it]
0it [00:00, ?it/s][A
1it [01:39, 99.96s/it][A
2it [03:18, 99.36s/it][A
3it [04:56, 98.42s/it][A
4it [06:36, 99.23s/it][A
5it [08:19, 100.70s/it][A
6it [10:03, 101.55s/it][A
7it [11:37, 99.09s/it] [A
8it [12:15, 91.93s/it][A
 10%|▉         | 10/101 [2:14:22<18:59:56, 751.61s/it]

epoch 16 innerdomain loss: 4.08032973451937



0it [00:00, ?it/s][A
1it [01:39, 99.80s/it][A
2it [03:21, 100.92s/it][A
3it [06:30, 130.14s/it][A
 10%|▉         | 10/101 [2:20:52<21:22:00, 845.28s/it]


RuntimeError: CUDA out of memory. Tried to allocate 3.00 GiB (GPU 0; 23.65 GiB total capacity; 14.76 GiB already allocated; 844.94 MiB free; 20.54 GiB reserved in total by PyTorch)

In [None]:
print(input_id.shape)
print(output_pred_id.shape, output_true_id.shape)

In [None]:
output_true.min()

In [None]:
plt.figure(figsize=(7,5))
plt.title('Innerdomain loss')
plt.xlabel('epoch')
plt.ylabel('MSE loss')
plt.plot(epoch_loss_list)
plt.savefig(f'{output_dir}/adadelta_loss_1e-1.png')