In [1]:
import sys

sys.path.append('/home/tuomas/Desktop/GITS/mCTSegmentation/mctseg/unet/')

import numpy as np
import os
import h5py
import time
import gc
import copy

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.functional as F

from model import UNet

In [2]:
def fuse_bn_sequential(block):
    """
    This function takes a sequential block and fuses the batch normalization with convolution
    :param model: nn.Sequential. Source resnet model
    :return: nn.Sequential. Converted block
    """
    print(block)
    if not isinstance(block, nn.Sequential):
        return block
    stack = []
    for m in block.children():
        if isinstance(m, nn.BatchNorm2d):
            if isinstance(stack[-1], nn.Conv2d):
                bn_st_dict = m.state_dict()
                conv_st_dict = stack[-1].state_dict()

                # BatchNorm params
                eps = m.eps
                mu = bn_st_dict['running_mean']
                var = bn_st_dict['running_var']
                gamma = bn_st_dict['weight']

                if 'bias' in bn_st_dict:
                    beta = bn_st_dict['bias']
                else:
                    beta = torch.zeros(gamma.size(0)).float().to(gamma.device)

                # Conv params
                W = conv_st_dict['weight']
                if 'bias' in conv_st_dict:
                    bias = conv_st_dict['bias']
                else:
                    bias = torch.zeros(W.size(0)).float().to(gamma.device)

                denom = torch.sqrt(var + eps)
                b = beta - gamma.mul(mu).div(denom)
                A = gamma.div(denom)
                bias *= A
                A = A.expand_as(W.transpose(0, -1)).transpose(0, -1)

                W.mul_(A)
                bias.add_(b)

                stack[-1].weight.data.copy_(W)
                if stack[-1].bias is None:
                    stack[-1].bias = torch.nn.Parameter(bias)
                else:
                    stack[-1].bias.data.copy_(bias)

        else:
            stack.append(m)
            

    if len(stack) > 1:
        return nn.Sequential(*stack)
    else:
        return stack[0]


def fuse_bn_recursively(model):
    for module_name in model._modules:
        model._modules[module_name] = fuse_bn_sequential(model._modules[module_name])
        if len(model._modules[module_name]._modules) > 0:
            fuse_bn_recursively(model._modules[module_name])
            if len(model._modules[module_name]._modules) > 0:
                fuse_bn_recursively(model._modules[module_name])

    return model

In [3]:
path = '/home/tuomas/Desktop/GITS/snapshots/2018_12_03_15_25'
files = os.listdir(path)
files.sort()

names = []
nets = []

for file in files:
    print(file) 
    f = file.split('_')
    if f[0] == 'fold':
        name = f[0]+'_'+f[1]+'_'+f[2]+'_'+f[3][:-4]
        names.append(name)
        net1 = UNet(24,6,2);
        net1.load_state_dict(torch.load(os.path.join(path,file)));
        net2 = fuse_bn_recursively(net1);
        #net2._modules['center'] = fuse_bn_sequential(net2._modules['center'])
        
        nets.append(copy.deepcopy(net2));
        
        net2 = None
        net1 = None
        gc.collect()
        


fold_0_epoch_26.pth
Encoder(
  (layers): Sequential(
    (conv_3x3_0): Sequential(
      (0): Sequential(
        (0): Conv2d(1, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
    (conv_3x3_1): Sequential(
      (0): Sequential(
        (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
  )
)
Sequential(
  (conv_3x3_0): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace)
    )
  )
  (conv_3x3_1): Sequential(
    (0): Sequential(
      (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2

fold_1_epoch_31.pth
Encoder(
  (layers): Sequential(
    (conv_3x3_0): Sequential(
      (0): Sequential(
        (0): Conv2d(1, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
    (conv_3x3_1): Sequential(
      (0): Sequential(
        (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
  )
)
Sequential(
  (conv_3x3_0): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace)
    )
  )
  (conv_3x3_1): Sequential(
    (0): Sequential(
      (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2

fold_2_epoch_24.pth
Encoder(
  (layers): Sequential(
    (conv_3x3_0): Sequential(
      (0): Sequential(
        (0): Conv2d(1, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
    (conv_3x3_1): Sequential(
      (0): Sequential(
        (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
  )
)
Sequential(
  (conv_3x3_0): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace)
    )
  )
  (conv_3x3_1): Sequential(
    (0): Sequential(
      (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2

fold_3_epoch_20.pth
Encoder(
  (layers): Sequential(
    (conv_3x3_0): Sequential(
      (0): Sequential(
        (0): Conv2d(1, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
    (conv_3x3_1): Sequential(
      (0): Sequential(
        (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
  )
)
Sequential(
  (conv_3x3_0): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace)
    )
  )
  (conv_3x3_1): Sequential(
    (0): Sequential(
      (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2

fold_4_epoch_48.pth
Encoder(
  (layers): Sequential(
    (conv_3x3_0): Sequential(
      (0): Sequential(
        (0): Conv2d(1, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
    (conv_3x3_1): Sequential(
      (0): Sequential(
        (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
  )
)
Sequential(
  (conv_3x3_0): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace)
    )
  )
  (conv_3x3_1): Sequential(
    (0): Sequential(
      (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2

session.pkl


In [4]:
for key,w in nets[0].state_dict().items():
    print(key)
    print(w.shape)

down1.layers.0.0.weight
torch.Size([24, 1, 3, 3])
down1.layers.0.0.bias
torch.Size([24])
down1.layers.1.0.weight
torch.Size([24, 24, 3, 3])
down1.layers.1.0.bias
torch.Size([24])
down2.layers.0.0.weight
torch.Size([48, 24, 3, 3])
down2.layers.0.0.bias
torch.Size([48])
down2.layers.1.0.weight
torch.Size([48, 48, 3, 3])
down2.layers.1.0.bias
torch.Size([48])
down3.layers.0.0.weight
torch.Size([96, 48, 3, 3])
down3.layers.0.0.bias
torch.Size([96])
down3.layers.1.0.weight
torch.Size([96, 96, 3, 3])
down3.layers.1.0.bias
torch.Size([96])
down4.layers.0.0.weight
torch.Size([192, 96, 3, 3])
down4.layers.0.0.bias
torch.Size([192])
down4.layers.1.0.weight
torch.Size([192, 192, 3, 3])
down4.layers.1.0.bias
torch.Size([192])
down5.layers.0.0.weight
torch.Size([384, 192, 3, 3])
down5.layers.0.0.bias
torch.Size([384])
down5.layers.1.0.weight
torch.Size([384, 384, 3, 3])
down5.layers.1.0.bias
torch.Size([384])
down6.layers.0.0.weight
torch.Size([768, 384, 3, 3])
down6.layers.0.0.bias
torch.Size([768

In [5]:
savename = 'UNetE3New.h5'
#h5 = h5py.File(savename,'w')

for name,net in zip(names,nets):
    W = OrderedDict()
    for key,w in net.state_dict().items():
        _name = key[:]
        W[_name] = w
    savename = '/media/tuomas/data/WRKNew/UNet_'+name[:6]+'_new.h5';
    print(savename)
    h5 = h5py.File(savename,'w')
    for key in W:
        if key.endswith('0.bias') or key.endswith('0.weight'):
            N = key.split('.')
            #print(key)
            if N[1] == 'layers':
                npart = N[2]
                npart = npart.split('_')
                dname = N[0]+'_'+npart[-1]+'_'+N[-1]
                #print(dname)
            else:
                npart = N[1]
                npart = npart.split('_')
                dname = N[0]+'_'+npart[-1]+'_'+N[-1]
                #print(dname)
            print(dname)
            T = W[key]
            #T = T.view(T.numel())
            T = T.cpu().numpy().astype('float64')
            #D = T.shape
            #if len(D) == 4:
            #    T = np.swapaxes(T,3,2)
            T = T.flatten()
            #print(T)    
            h5.create_dataset(dname,data=T)
        if key.startswith('mixer'):
            dname = key            
            T = W[key]
            T = T.cpu().numpy().astype('float64')
            T = T.flatten()
            print(dname)
            print(T.shape)
            h5.create_dataset(dname,data=T)
            
    h5.close()
    

/media/tuomas/data/WRKNew/UNet_fold_0_new.h5
down1_0_weight
down1_0_bias
down1_1_weight
down1_1_bias
down2_0_weight
down2_0_bias
down2_1_weight
down2_1_bias
down3_0_weight
down3_0_bias
down3_1_weight
down3_1_bias
down4_0_weight
down4_0_bias
down4_1_weight
down4_1_bias
down5_0_weight
down5_0_bias
down5_1_weight
down5_1_bias
down6_0_weight
down6_0_bias
down6_1_weight
down6_1_bias
center_0_weight
center_0_bias
center_1_weight
center_1_bias
up6_0_weight
up6_0_bias
up6_1_weight
up6_1_bias
up5_0_weight
up5_0_bias
up5_1_weight
up5_1_bias
up4_0_weight
up4_0_bias
up4_1_weight
up4_1_bias
up3_0_weight
up3_0_bias
up3_1_weight
up3_1_bias
up2_0_weight
up2_0_bias
up2_1_weight
up2_1_bias
up1_0_weight
up1_0_bias
up1_1_weight
up1_1_bias
mixer.weight
(24,)
mixer.bias
(1,)
/media/tuomas/data/WRKNew/UNet_fold_1_new.h5
down1_0_weight
down1_0_bias
down1_1_weight
down1_1_bias
down2_0_weight
down2_0_bias
down2_1_weight
down2_1_bias
down3_0_weight
down3_0_bias
down3_1_weight
down3_1_bias
down4_0_weight
down4_0_