In [1]:
import math
import numpy as np
import torch
import pkbar
import sys
from unet3d.config import *
from tqdm import tqdm
from torch.nn import CrossEntropyLoss
from torch.nn.functional import one_hot
from torch.optim import Adam
from torchsummary import summary
#from torch.utils.tensorboard import SummaryWriter
from unet3d.unet3d_vgg16 import UNet3D_VGG16
from utils.Visualization import ImageSliceViewer3D
from patchify import patchify
import nrrd

torch.manual_seed(0)

%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
# Load some data and uniformly sample from it
scan, _ = nrrd.read(DATASET_PATH + '/SAIAD 1/scan.nrrd')
segm, _ = nrrd.read(DATASET_PATH + '/SAIAD 1/segm.nrrd')

#scan_patches = patchify(scan, PATCH_SIZE, step=PATCH_SIZE).reshape(-1,PATCH_SIZE[0],PATCH_SIZE[1], PATCH_SIZE[2])

## Random Sampling: Uniform
scan_patches = []
segm_patches = []
side_len = PATCH_SIZE[0]
for i in tqdm(range(128)):
    # Center coordinates
    cx = torch.randint(0,scan.shape[0],(1,))[0]
    cy = torch.randint(0,scan.shape[1],(1,))[0]
    cz = torch.randint(0,scan.shape[2],(1,))[0]
    
    #print(f"Center: {[cx,cy,cz]}")
    bbox_x = [max(cx - side_len//2, 0), min(scan.shape[0], cx+side_len//2)]
    bbox_y = [max(cy - side_len//2, 0), min(scan.shape[1], cy+side_len//2)]
    bbox_z = [max(cz - side_len//2, 0), min(scan.shape[2], cz+side_len//2)]

    # Random patch
    pad_x = (-min(cx - side_len//2,0), max(side_len//2 + cx - scan.shape[0], 0))
    pad_y = (-min(cy - side_len//2,0), max(side_len//2 + cy - scan.shape[1], 0))
    pad_z = (-min(cz - side_len//2,0), max(side_len//2 + cz - scan.shape[2], 0))
    
    #print([pad_x, pad_y, pad_z])

    segm_patch_prepad = segm[bbox_x[0]:bbox_x[1], bbox_y[0]:bbox_y[1], bbox_z[0]:bbox_z[1]]
    scan_patch_prepad = scan[bbox_x[0]:bbox_x[1], bbox_y[0]:bbox_y[1], bbox_z[0]:bbox_z[1]]
    scan_patch = np.pad(scan_patch_prepad,(pad_x, pad_y, pad_z), 'constant', constant_values=0)
    segm_patch = np.pad(segm_patch_prepad,(pad_x, pad_y, pad_z), 'constant', constant_values=0)
    
    scan_patches.append(scan_patch)
    segm_patches.append(segm_patch)
    
scan_patches = torch.tensor(np.array(scan_patches)).float()
scan_patches = torch.unsqueeze(scan_patches,1) # add channel dimension, send to gpu
segm_patches = np.array(segm_patches).reshape(-1,PATCH_SIZE[0],PATCH_SIZE[1], PATCH_SIZE[2])
segm_patches = one_hot(torch.tensor(segm_patches).to(torch.int64), num_classes=NUM_CLASSES).permute(0,4,1,2,3).float()# put channels first, send to gpu

print([scan_patches.shape, scan_patches.dtype])
print([segm_patches.shape, segm_patches.dtype])


100%|█████████████████████████████████████████████████████████████████████| 128/128 [00:01<00:00, 124.88it/s]


[torch.Size([128, 1, 96, 96, 96]), torch.float32]
[torch.Size([128, 5, 96, 96, 96]), torch.float32]


In [3]:
# Training
torch.backends.cudnn.benchmark = True # Speeds up stuff
torch.backends.cudnn.enabled = True

device = torch.device('cuda')
model = UNet3D_VGG16(in_channels=IN_CHANNELS , num_classes= NUM_CLASSES).cuda()
criterion = CrossEntropyLoss(weight=torch.Tensor(np.array(CE_WEIGHTS)/np.array(CE_WEIGHTS).sum())).cuda()
optimizer = Adam(params=model.parameters())

min_valid_loss = math.inf

for epoch in range(TRAINING_EPOCH):
    # progress bar
    #kbar = pkbar.Kbar(target=120, epoch=epoch, num_epochs=TRAINING_EPOCH, width=8, always_stateful=False)

    train_loss = 0.0
    model.train()
    for i in range(0,120,TRAIN_BATCH_SIZE):
        image, ground_truth = scan_patches[i:i+TRAIN_BATCH_SIZE], segm_patches[i:i+TRAIN_BATCH_SIZE]
    
        optimizer.zero_grad(set_to_none=True)
        
        target = model(image.cuda())
        loss = criterion(target, ground_truth.cuda())
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        #kbar.update(i, values=[("loss", train_loss)])
    
    valid_loss = 0.0
    model.eval()
    for i in range(120,128,TRAIN_BATCH_SIZE):
        image, ground_truth = scan_patches[i:i+TRAIN_BATCH_SIZE], segm_patches[i:i+TRAIN_BATCH_SIZE]
        target = model(image.cuda())
        loss = criterion(target,ground_truth.cuda())
        valid_loss += loss.item()
    #kbar.update(i, values=[("Validation loss", valid_loss)])

        
    #writer.add_scalar("Loss/Train", train_loss / 120, epoch)
    #writer.add_scalar("Loss/Validation", valid_loss / 8, epoch)
    
    #print(f'Epoch {epoch+1} \t\t Training Loss: {train_loss / 8} \t\t Validation Loss: {valid_loss / 8}')
    
    if min_valid_loss > valid_loss:
        print(f'Validation Loss Decreased({min_valid_loss:.6f}--->{valid_loss:.6f}) \t Saving The Model')
        min_valid_loss = valid_loss
        # Saving State Dict
        torch.save(model.state_dict(), f'checkpoints/test_epoch{epoch}_valLoss{min_valid_loss}.pth')


  out = self.softmax(out)


Validation Loss Decreased(inf--->0.031945) 	 Saving The Model
Validation Loss Decreased(0.031945--->0.031413) 	 Saving The Model
Validation Loss Decreased(0.031413--->0.031326) 	 Saving The Model


KeyboardInterrupt: 

In [4]:
from utils.Visualization import ImageSliceViewer3D
n=1
pred = model(scan_patches[n:n+1].cuda())
pred.size()
pred_index = np.array(torch.argmax(pred[0].cpu(), dim=0))
ImageSliceViewer3D(pred_index, np.array(scan_patches[n:n+1][0,0]))

interactive(children=(RadioButtons(description='Slice plane selection:', options=('x-y', 'y-z', 'z-x'), style=…

<utils.Visualization.ImageSliceViewer3D at 0x147d9c702d60>

In [9]:
n=1
ImageSliceViewer3D(np.array(segm_patches[n:n+1][0,0]))

interactive(children=(RadioButtons(description='Slice plane selection:', options=('x-y', 'y-z', 'z-x'), style=…

<utils.Visualization.ImageSliceViewer3D at 0x1494635c47c0>

In [None]:
np.unique(scan)

In [10]:
idx = np.arange(128)
np.random.shuffle(idx)
print(idx)

[ 36 125  77  99  75  91  95 110  57 102  11 111  80  97  40  17 126  53
   2  54  30 123 116  28  23  93  81  88  49  46  96  62  90 124  61   3
   7  86  34   6  55  84 109   9  20   5  78  67 112  64  25  44  10  92
  66  98 108  59  41  50  24  63 117  76 104 120  56   8  26  43  33  37
  89   4 106  14   1  35 107 114   0  45  94  69 100  83  18  60  12 119
  70  51  31  87  16  22 115  13  38 127  42  68  74 118 103 121 101 122
  15  19  85  73  58  29  21  52 105  39  82  79  47  71 113  72  48  65
  27  32]


In [5]:
import time
torch.cuda.empty_cache()
start_time = time.time()
summary(model=model, input_size=(1, 96,96,96), batch_size=-1, device="cuda")
print("--- %s seconds ---" % (time.time() - start_time))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 32, 96, 96, 96]             896
       BatchNorm3d-2       [-1, 32, 96, 96, 96]              64
              ReLU-3       [-1, 32, 96, 96, 96]               0
            Conv3d-4       [-1, 64, 96, 96, 96]          55,360
       BatchNorm3d-5       [-1, 64, 96, 96, 96]             128
              ReLU-6       [-1, 64, 96, 96, 96]               0
         MaxPool3d-7       [-1, 64, 48, 48, 48]               0
       Conv3DBlock-8  [[-1, 64, 48, 48, 48], [-1, 64, 96, 96, 96]]               0
            Conv3d-9       [-1, 64, 48, 48, 48]         110,656
      BatchNorm3d-10       [-1, 64, 48, 48, 48]             128
             ReLU-11       [-1, 64, 48, 48, 48]               0
           Conv3d-12      [-1, 128, 48, 48, 48]         221,312
      BatchNorm3d-13      [-1, 128, 48, 48, 48]             256
             ReLU-14