In [1]:
import os,datetime
import imageio
import h5py
import torch
import torch.nn as nn
import numpy as np
from threedunet import unet_residual_3d
from torchsummary import summary

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def read_h5(filename, dataset=''):
    """Read one entry of an HDF5 file into an in-memory NumPy array.

    Args:
        filename: Path of the HDF5 file; it is opened read-only.
        dataset: Key of the entry to read. When left as the empty string,
            the first key listed by the file is used.

    Returns:
        numpy.ndarray built from the selected entry.

    Raises:
        IndexError: If no key was given and the file lists no keys.
    """
    # The context manager closes the file handle on every exit path. The data
    # must therefore be copied into an array before the block ends.
    with h5py.File(filename, 'r') as fid:
        if dataset == '':
            keys = list(fid)
            if not keys:
                raise IndexError(
                    "HDF5 file %r contains no datasets" % (filename,))
            dataset = keys[0]
        return np.array(fid[dataset])

In [2]:
# Build the network with one input channel and 13 output channels,
# then move its parameters onto the selected device.
model = unet_residual_3d(in_channel=1, out_channel=13)
model = model.to(device)

In [3]:
# Print a per-layer summary for a single-channel 112x112x112 input.
summary(model, input_size=(1, 112, 112, 112))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
  ReplicationPad3d-1     [-1, 1, 112, 116, 116]               0
            Conv3d-2    [-1, 28, 112, 112, 112]             728
SynchronizedBatchNorm3d-3    [-1, 28, 112, 112, 112]              56
               ELU-4    [-1, 28, 112, 112, 112]               0
  ReplicationPad3d-5    [-1, 28, 112, 114, 114]               0
            Conv3d-6    [-1, 28, 112, 112, 112]           7,084
SynchronizedBatchNorm3d-7    [-1, 28, 112, 112, 112]              56
               ELU-8    [-1, 28, 112, 112, 112]               0
  ReplicationPad3d-9    [-1, 28, 112, 114, 114]               0
           Conv3d-10    [-1, 28, 112, 112, 112]           7,084
SynchronizedBatchNorm3d-11    [-1, 28, 112, 112, 112]              56
              ELU-12    [-1, 28, 112, 112, 112]               0
 ReplicationPad3d-13    [-1, 28, 112, 114, 114]               0
           Conv3d-14   

In [4]:
# Replicate the model on up to four GPUs, but never request more devices than
# this machine exposes: a hard-coded range(4) fails on hosts with fewer GPUs.
gpu_ids = list(range(min(4, torch.cuda.device_count())))
# Guard against re-running the cell: wrapping an already wrapped model would
# nest DataParallel modules and change what `model.module` refers to.
if not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model, device_ids=gpu_ids)
# NOTE(review): the replication callback below is disabled; presumably it is
# only required for multi-GPU training with synchronized batch norm. Confirm.
#patch_replication_callback(model)
model = model.to(device)

    There is an imbalance between your GPUs. You may want to exclude GPU 1 which
    has less than 75% of the memory or cores of GPU 0. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable.


In [5]:
# Path of the pretrained checkpoint file. The loading cell later rebinds this
# name to the loaded checkpoint contents.
checkpoint = "checkpoint_50000.pth.tar"

In [6]:
# load pre-trained model
print('Load pretrained checkpoint: ', checkpoint)
# map_location='cpu' lets a checkpoint saved from GPU tensors load even when
# that GPU (or any GPU) is unavailable. The weights are copied into the
# model's own parameters later, when the state dict is loaded.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load(checkpoint, map_location='cpu')
print('checkpoints: ', checkpoint.keys())

Load pretrained checkpoint:  checkpoint_50000.pth.tar
checkpoints:  dict_keys(['iteration', 'state_dict', 'optimizer', 'lr_scheduler'])


In [7]:
# update model weights
if 'state_dict' in checkpoint:
    pretrained_dict = checkpoint['state_dict']
    model_dict = model.module.state_dict() # nn.DataParallel
    n_checkpoint_entries = len(pretrained_dict)
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # Report dropped entries instead of discarding them silently. If every
    # entry is dropped, the model keeps its initial weights.
    # NOTE(review): a common cause is a key-prefix mismatch such as a leading
    # 'module.' - verify against how the checkpoint was saved.
    n_skipped = n_checkpoint_entries - len(pretrained_dict)
    if n_skipped:
        print("WARNING: %d of %d checkpoint entries match no model parameter "
              "name and were ignored" % (n_skipped, n_checkpoint_entries))
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.module.load_state_dict(model_dict) # nn.DataParallel

    print("new state dict loaded ")
else:
    # Make the no-op visible rather than continuing with untrained weights.
    print("WARNING: checkpoint has no 'state_dict' entry; model weights unchanged")

# Switch to evaluation mode; as the last expression, this also displays the model.
model.eval()

new state dict loaded 


DataParallel(
  (module): unet_residual_3d(
    (downE): Sequential(
      (0): Sequential(
        (0): ReplicationPad3d((2, 2, 2, 2, 0, 0))
        (1): Conv3d(1, 28, kernel_size=(1, 5, 5), stride=(1, 1, 1))
        (2): SynchronizedBatchNorm3d(28, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ELU(alpha=1.0, inplace=True)
      )
      (1): Sequential(
        (0): ReplicationPad3d((1, 1, 1, 1, 0, 0))
        (1): Conv3d(28, 28, kernel_size=(1, 3, 3), stride=(1, 1, 1))
        (2): SynchronizedBatchNorm3d(28, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ELU(alpha=1.0, inplace=True)
      )
      (2): residual_block_2d(
        (conv): Sequential(
          (0): Sequential(
            (0): ReplicationPad3d((1, 1, 1, 1, 0, 0))
            (1): Conv3d(28, 28, kernel_size=(1, 3, 3), stride=(1, 1, 1))
            (2): SynchronizedBatchNorm3d(28, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (3): 

In [8]:
volume_name = 'MvsFCSmale4week-3-DownSamp_im.h5'

# Read the CT volume from disk as a NumPy array and report its shape.
image_volume = read_h5(volume_name)
print(image_volume.shape)
vol = image_volume

# Convert to a float tensor on the selected device, then prepend two
# singleton axes so the tensor is 5-D, as printed below.
volume = (
    torch.from_numpy(vol)
    .to(device, dtype=torch.float)
    .unsqueeze(0)
    .unsqueeze(0)
)
print(volume.shape)

(112, 112, 112)
torch.Size([1, 1, 112, 112, 112])


In [9]:
# Inference only: disable autograd so the forward pass does not retain the
# activation graph for the whole 3D volume.
with torch.no_grad():
    pred = model(volume)
# Drop the leading singleton axis that was added before the forward pass.
pred = pred.squeeze(0)

In [10]:
# Bring the prediction tensor back to host memory for NumPy post-processing.
pred = pred.to("cpu")

In [11]:
# Reduce axis 0 to the index of its largest value, stored as uint16.
pred_final = pred.detach().numpy().argmax(axis=0).astype(np.uint16)
print("Shape of Predictions after argmax() function ", pred_final.shape)

Shape of Predictions after argmax() function  (112, 112, 112)


In [12]:
# Write the prediction array to a new HDF5 file. The context manager closes
# the handle even if create_dataset raises.
with h5py.File('MvsFCSmale4week-3-DownSamp_pred.h5', 'w') as hf1:
    hf1.create_dataset('dataset1', data=pred_final)
    # Printed while the file is still open, so the handle repr shows the open file.
    print("Prediction volume created and saved" , hf1)

Prediction volume created and saved <HDF5 file "MvsFCSmale4week-3-DownSamp_pred.h5" (mode r+)>
