In [2]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models, transforms, utils
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import scipy.misc
from PIL import Image
import json
%matplotlib inline
import SimpleITK as sitk

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
model = torch.load(r'Z:\grodriguez\CardiacOCT\model_trial_8.pt')
for i in range(len(model)):
  model[i] = model[i].to(device)



Please cite the following paper when using nnUNet:

Isensee, F., Jaeger, P.F., Kohl, S.A.A. et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nat Methods (2020). https://doi.org/10.1038/s41592-020-01008-z


If you have questions or suggestions, feel free to open an issue at https://github.com/MIC-DKFZ/nnUNet



In [6]:
model

[ModuleList(
   (0-2): 3 x Sequential(
     (0): StackedConvLayers(
       (blocks): Sequential(
         (0): ConvDropoutNormNonlin(
           (conv): Conv2d(960, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
           (instnorm): InstanceNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
           (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
         )
       )
     )
     (1): StackedConvLayers(
       (blocks): Sequential(
         (0): ConvDropoutNormNonlin(
           (conv): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
           (instnorm): InstanceNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
           (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
         )
       )
     )
   )
   (3): Sequential(
     (0): StackedConvLayers(
       (blocks): Sequential(
         (0): ConvDropoutNormNonlin(
           (conv): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), paddin

In [7]:
model_weights = torch.load(r'Z:\grodriguez\CardiacOCT\data-2d\results\nnUNet\2d\Task508_CardiacOCT\nnUNetTrainer_V2_Loss_CEandDice_Weighted__nnUNetPlansv2.1\fold_0\model_final_checkpoint.model')
state_dict = model_weights['state_dict']

In [10]:
model_params = np.load(r'Z:\grodriguez\CardiacOCT\data-2d\results\nnUNet\2d\Task508_CardiacOCT\nnUNetTrainer_V2_Loss_CEandDice_Weighted__nnUNetPlansv2.1\fold_0\model_final_checkpoint.model.pkl', allow_pickle=True)
model_params['plans']['plans_per_stage']

{0: {'batch_size': 4,
  'num_pool_per_axis': [7, 7],
  'patch_size': array([768, 768], dtype=int64),
  'median_patient_size_in_voxels': array([  1, 691, 691], dtype=int64),
  'current_spacing': array([999.,   1.,   1.]),
  'original_spacing': array([999.,   1.,   1.]),
  'pool_op_kernel_sizes': [[2, 2],
   [2, 2],
   [2, 2],
   [2, 2],
   [2, 2],
   [2, 2],
   [2, 2]],
  'conv_kernel_sizes': [[3, 3],
   [3, 3],
   [3, 3],
   [3, 3],
   [3, 3],
   [3, 3],
   [3, 3],
   [3, 3]],
  'do_dummy_2D_data_aug': False}}

In [12]:
transform = transforms.Compose([
    transforms.ToTensor()
])

image_sample = np.load(r'Z:\grodriguez\CardiacOCT\data-2d\nnUNet_preprocessed\Task508_CardiacOCT\nnUNetData_plans_v2.1_2D_stage0\ESTNEMC0027_1_frame27_001.npz', allow_pickle=True)['data'][:,0,:,:].T
print(image_sample.shape)
image_sample=transform(image_sample)
image_sample=image_sample.unsqueeze(0)
image_sample = image_sample.to(device)
print(image_sample.size())

(691, 691, 22)
torch.Size([1, 22, 691, 691])


In [24]:
conv2d_list = []

for i in range(len(model)):

  for module in model[i].modules():

    if isinstance(module, torch.nn.Conv2d):
      conv2d_list.append(module)

conv2d_list = conv2d_list[:30]

In [25]:
conv2d_list

[Conv2d(960, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(960, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(960, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(1, 32, kernel_size=

In [14]:
del conv2d_list[2:6]

In [26]:
conv2d_list

[Conv2d(960, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(960, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(960, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(1, 32, kernel_size=

In [27]:
weights_list = []

for key, value in state_dict.items():

  if 'conv.weight' in key.lower():

    weights_list.append(value.to(device))


In [30]:
for i in weights_list:
    print(i.shape)

torch.Size([480, 960, 3, 3])
torch.Size([480, 480, 3, 3])
torch.Size([480, 960, 3, 3])
torch.Size([480, 480, 3, 3])
torch.Size([480, 960, 3, 3])
torch.Size([480, 480, 3, 3])
torch.Size([256, 512, 3, 3])
torch.Size([256, 256, 3, 3])
torch.Size([128, 256, 3, 3])
torch.Size([128, 128, 3, 3])
torch.Size([64, 128, 3, 3])
torch.Size([64, 64, 3, 3])
torch.Size([32, 64, 3, 3])
torch.Size([32, 32, 3, 3])
torch.Size([32, 21, 3, 3])
torch.Size([32, 32, 3, 3])
torch.Size([64, 32, 3, 3])
torch.Size([64, 64, 3, 3])
torch.Size([128, 64, 3, 3])
torch.Size([128, 128, 3, 3])
torch.Size([256, 128, 3, 3])
torch.Size([256, 256, 3, 3])
torch.Size([480, 256, 3, 3])
torch.Size([480, 480, 3, 3])
torch.Size([480, 480, 3, 3])
torch.Size([480, 480, 3, 3])
torch.Size([480, 480, 3, 3])
torch.Size([480, 480, 3, 3])
torch.Size([480, 480, 3, 3])
torch.Size([480, 480, 3, 3])


In [31]:
outputs = []
names = []
for layer in conv2d_list:
    image_sample = layer(image_sample)
    outputs.append(image_sample)
    names.append(str(conv2d_list[layer]))
    print('a')

#print feature_maps
for feature_map in outputs:
    print(feature_map.shape)

RuntimeError: Given groups=1, weight of size [480, 960, 3, 3], expected input[1, 22, 691, 691] to have 960 channels, but got 22 channels instead

In [21]:
processed = []
for feature_map in outputs:
    feature_map = feature_map.squeeze(0)
    gray_scale = torch.sum(feature_map,2)
    gray_scale = gray_scale / feature_map.shape[0]
    processed.append(gray_scale.data.cpu().numpy())
for fm in processed:
    print(fm.shape)

In [22]:
fig = plt.figure(figsize=(30, 50))
for i in range(len(processed)):
    a = fig.add_subplot(5, 6, i+1)
    imgplot = plt.imshow(processed[i].astype(np.uint8))
    a.axis("off")
    a.set_title(names[i].split('(')[0], fontsize=30)

<Figure size 2160x3600 with 0 Axes>