In [1]:
from json import load
import math
import numpy as np
import torch
import pkbar
import sys
import pkbar
from unet3d.config import *
from tqdm import tqdm
from torch.nn import CrossEntropyLoss
from torch.nn.functional import one_hot
from torch.optim import Adam
from unet3d.unet3d_vgg16 import UNet3D_VGG16
from utils.Other import get_headers
from unet3d.dataset import SAIADDataset, WrappedDataLoader, to_device
from torch.utils.data import DataLoader
from pynvml.smi import nvidia_smi

torch.manual_seed(0)

%load_ext autoreload
%autoreload 2
%matplotlib inline

torch.backends.cudnn.benchmark = True # Speeds up stuff
torch.backends.cudnn.enabled = True
device = torch.device('cuda')
nvsmi = nvidia_smi.getInstance()

_,_,patient_names = get_headers(DATASET_PATH)

### FOR TESTING ###
#TRAIN_BATCHES_PER_EPOCH=20
#VAL_BATCHES_PER_EPOCH=5


In [2]:
excl_patients_training = ['SAIAD 15', 'SAIAD 11'] #patients for validation/testing
excl_patients_val = list(set(patient_names) - set(excl_patients_training))

print("Training with val patients:", excl_patients_training)



## Load dataset ##
train_dataset = SAIADDataset(
    excl_patients=excl_patients_training,
    load_data_to_memory=True,
    n_batches=TRAIN_BATCHES_PER_EPOCH,
    )
val_dataset = SAIADDataset(
    excl_patients=excl_patients_val,
    load_data_to_memory=True,
    n_batches=VAL_BATCHES_PER_EPOCH,
)

train_dataloader = DataLoader(
    train_dataset, 
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=False, 
    pin_memory=False, ###
    num_workers=NUM_WORKERS
    )
val_dataloader = DataLoader(
    val_dataset, 
    batch_size=VAL_BATCH_SIZE,
    shuffle=False, 
    pin_memory=False, ###
    num_workers=NUM_WORKERS
    )

val_dataloader = WrappedDataLoader(val_dataloader, to_device, device)
train_dataloader = WrappedDataLoader(train_dataloader, to_device, device)


Training with val patients: ['SAIAD 15', 'SAIAD 11']
Fetching patients probabilities...


100%|███████████████████████████████████████████| 20/20 [00:06<00:00,  3.03it/s]


Fetching patients probabilities...


100%|█████████████████████████████████████████████| 2/2 [00:00<00:00,  2.38it/s]


In [None]:
## Model ##
model = UNet3D_VGG16(
    in_channels=IN_CHANNELS , 
    num_classes=NUM_CLASSES,
    use_softmax_end=False #set this to false for training with CELoss
    ).to(device)

loss_fn = CrossEntropyLoss(weight=torch.Tensor(np.array(CE_WEIGHTS)/np.array(CE_WEIGHTS).sum())).cuda()
optimizer = Adam(params=model.parameters(), lr=LR)



## Training ##
min_valid_loss = math.inf

for epoch in range(EPOCHS):
    # Check memory usage
    mem_query = nvsmi.DeviceQuery('memory.free, memory.total')['gpu'][0]['fb_memory_usage']
    print(f"Mem. Usage - Used: {mem_query['total']-mem_query['free'] }/{mem_query['total']} MB")

    # progress bar
    kbar = pkbar.Kbar(target=TRAIN_BATCHES_PER_EPOCH+VAL_BATCHES_PER_EPOCH, epoch=epoch, num_epochs=EPOCHS, width=8, always_stateful=True)

    train_loss = 0.0
    model.train()
    i=1
    batch_num = 1
    for X_batch, y_batch in train_dataloader:  
        pred = model(X_batch)
        loss = loss_fn(pred, y_batch)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        train_loss += loss.cpu().detach()
        kbar.update(i, values=[("loss", train_loss/batch_num)])
        i+=1
        batch_num+=1
        
    # Tensorboard #
    #writer.add_scalar("Loss/train", train_loss, epoch)

    valid_loss = 0.0
    model.eval()
    batch_num = 1
    with torch.no_grad():
        for X_batch, y_batch in val_dataloader:
            pred = model(X_batch)
            loss = loss_fn(pred,y_batch)
            valid_loss += loss.cpu().detach()
            kbar.update(i, values=[("Validation loss", valid_loss/batch_num)])
            i+=1
            batch_num+=1
    valid_loss /= batch_num
            
    # Tensorboard #
    #writer.add_scalar("Loss/val", valid_loss, epoch)

    if min_valid_loss > valid_loss:
        print(f'\t Validation Loss Decreased({min_valid_loss:.6f}--->{valid_loss:.6f}) \t Saving The Model')
        min_valid_loss = valid_loss
        # Saving State Dict
        torch.save(model.state_dict(), f'checkpoints/no_aug_epoch{epoch}_valLoss{min_valid_loss:.6f}.pth')
    elif (epochs+1)%EPOCHS//10 == 0:
        print(f'\t Reached checkpoint. \t Saving The Model')
        torch.save(model.state_dict(), f'checkpoints/no_aug_epoch{epoch}_valLoss{min_valid_loss:.6f}.pth')




Mem. Usage - Used: 1231.3125/16160.5 MB
Epoch: 1/100
	 Validation Loss Decreased(inf--->0.058866) 	 Saving The Model
Mem. Usage - Used: 15171.3125/16160.5 MB
Epoch: 2/100
Mem. Usage - Used: 15171.3125/16160.5 MB
Epoch: 3/100
Mem. Usage - Used: 14165.3125/16160.5 MB
Epoch: 4/100
Mem. Usage - Used: 14165.3125/16160.5 MB
Epoch: 5/100
Mem. Usage - Used: 14165.3125/16160.5 MB
Epoch: 6/100
	 Validation Loss Decreased(0.058866--->0.056845) 	 Saving The Model
Mem. Usage - Used: 14165.3125/16160.5 MB
Epoch: 7/100
	 Validation Loss Decreased(0.056845--->0.055017) 	 Saving The Model
Mem. Usage - Used: 14165.3125/16160.5 MB
Epoch: 8/100
Mem. Usage - Used: 14165.3125/16160.5 MB
Epoch: 9/100
Mem. Usage - Used: 14165.3125/16160.5 MB
Epoch: 10/100
Mem. Usage - Used: 14165.3125/16160.5 MB
Epoch: 11/100
Mem. Usage - Used: 14165.3125/16160.5 MB
Epoch: 12/100
Mem. Usage - Used: 14165.3125/16160.5 MB
Epoch: 13/100
Mem. Usage - Used: 14165.3125/16160.5 MB
Epoch: 14/100
Mem. Usage - Used: 14165.3125/16160.5 

In [11]:
from utils.Visualization import ImageSliceViewer3D
n=70
pred = model(scan_patches[n:n+1].cuda())
pred.size()
pred_index = np.array(torch.argmax(pred[0].cpu(), dim=0))
ImageSliceViewer3D(pred_index, np.array(scan_patches[n:n+1][0,0]))

interactive(children=(RadioButtons(description='Slice plane selection:', options=('x-y', 'y-z', 'z-x'), style=…

<utils.Visualization.ImageSliceViewer3D at 0x151724f16a90>

In [5]:
loss_fn(pred, segm_patches[n:n+1].cuda())

tensor(0.4688, device='cuda:0', grad_fn=<DivBackward1>)

In [14]:
torch.unique(pred)

tensor([nan, nan, nan,  ..., nan, nan, nan], device='cuda:0',
       grad_fn=<Unique2Backward0>)

In [10]:
idx = np.arange(128)
np.random.shuffle(idx)
print(idx)

[ 36 125  77  99  75  91  95 110  57 102  11 111  80  97  40  17 126  53
   2  54  30 123 116  28  23  93  81  88  49  46  96  62  90 124  61   3
   7  86  34   6  55  84 109   9  20   5  78  67 112  64  25  44  10  92
  66  98 108  59  41  50  24  63 117  76 104 120  56   8  26  43  33  37
  89   4 106  14   1  35 107 114   0  45  94  69 100  83  18  60  12 119
  70  51  31  87  16  22 115  13  38 127  42  68  74 118 103 121 101 122
  15  19  85  73  58  29  21  52 105  39  82  79  47  71 113  72  48  65
  27  32]


In [2]:
# Load some data and uniformly sample from it
scan, _ = nrrd.read(DATASET_PATH + '/SAIAD 1/scan.nrrd')
segm, _ = nrrd.read(DATASET_PATH + '/SAIAD 1/segm.nrrd')

#scan_patches = patchify(scan, PATCH_SIZE, step=PATCH_SIZE).reshape(-1,PATCH_SIZE[0],PATCH_SIZE[1], PATCH_SIZE[2])

## Random Sampling: Uniform
scan_patches = []
segm_patches = []
side_len = PATCH_SIZE[0]
for i in tqdm(range(128)):
    # Center coordinates
    cx = torch.randint(0,scan.shape[0],(1,))[0]
    cy = torch.randint(0,scan.shape[1],(1,))[0]
    cz = torch.randint(0,scan.shape[2],(1,))[0]
    
    #print(f"Center: {[cx,cy,cz]}")
    bbox_x = [max(cx - side_len//2, 0), min(scan.shape[0], cx+side_len//2)]
    bbox_y = [max(cy - side_len//2, 0), min(scan.shape[1], cy+side_len//2)]
    bbox_z = [max(cz - side_len//2, 0), min(scan.shape[2], cz+side_len//2)]

    # Random patch
    pad_x = (-min(cx - side_len//2,0), max(side_len//2 + cx - scan.shape[0], 0))
    pad_y = (-min(cy - side_len//2,0), max(side_len//2 + cy - scan.shape[1], 0))
    pad_z = (-min(cz - side_len//2,0), max(side_len//2 + cz - scan.shape[2], 0))
    
    #print([pad_x, pad_y, pad_z])

    segm_patch_prepad = segm[bbox_x[0]:bbox_x[1], bbox_y[0]:bbox_y[1], bbox_z[0]:bbox_z[1]]
    scan_patch_prepad = scan[bbox_x[0]:bbox_x[1], bbox_y[0]:bbox_y[1], bbox_z[0]:bbox_z[1]]
    scan_patch = np.pad(scan_patch_prepad,(pad_x, pad_y, pad_z), 'constant', constant_values=0)
    segm_patch = np.pad(segm_patch_prepad,(pad_x, pad_y, pad_z), 'constant', constant_values=0)
    
    scan_patches.append(scan_patch)
    segm_patches.append(segm_patch)
    
scan_patches = torch.tensor(np.array(scan_patches)).float()
scan_patches = torch.unsqueeze(scan_patches,1) # add channel dimension, send to gpu
segm_patches = np.array(segm_patches).reshape(-1,PATCH_SIZE[0],PATCH_SIZE[1], PATCH_SIZE[2])
segm_patches = one_hot(torch.tensor(segm_patches).to(torch.int64), num_classes=NUM_CLASSES).permute(0,4,1,2,3).float()# put channels first, send to gpu

print([scan_patches.shape, scan_patches.dtype])
print([segm_patches.shape, segm_patches.dtype])


100%|██████████████████████████████████████████████████████████████████████████████████| 128/128 [00:01<00:00, 121.97it/s]


[torch.Size([128, 1, 96, 96, 96]), torch.float32]
[torch.Size([128, 5, 96, 96, 96]), torch.float32]
