In [1]:
import torch
from models.unet_3D import unet_3D
from models.unetUtils import Pad
import numpy as np
import models

In [2]:
def get_mod_details(model):
    """Collect the named sub-modules of ``model`` that are operations of interest.

    Walks ``model.named_modules()`` and keeps every entry with a non-empty
    name whose module is an ``Upsample``, ``Conv3d``, ``ReLU``, ``GroupNorm``,
    ``MaxPool3d`` or the project's ``models.unetUtils.Pad``.

    Args:
        model: a ``torch.nn.Module`` to inspect.

    Returns:
        dict with two parallel lists, in traversal order:
        ``'name'`` (qualified module names) and ``'layer'`` (the modules).
    """
    torch_kinds = (
        torch.nn.modules.upsampling.Upsample,
        torch.nn.Conv3d,
        torch.nn.ReLU,
        torch.nn.GroupNorm,
        torch.nn.MaxPool3d,
    )
    kept_names, kept_layers = [], []
    for name, layer in model.named_modules():
        # The root module is reported under the empty name; skip it.
        if name == "":
            continue
        # The project Pad type is only looked up when no torch type matched.
        if isinstance(layer, torch_kinds) or isinstance(layer, models.unetUtils.Pad):
            kept_names.append(name)
            kept_layers.append(layer)
    return {'name': kept_names, 'layer': kept_layers}

In [3]:
# Model configuration for the memory study.
inchan = 1
chanscale = 32
# Per-level channel counts: the listed widths integer-divided by chanscale.
chans = [i//chanscale for i in [64, 128, 256, 512, 1024]]
outsize = 12
interp = (512,512,198)
# outsize is passed as n_classes and interp as the interpolation argument.
mod = unet_3D(chans, n_classes=outsize, in_channels=inchan, interpolation = interp)

# Gather the model's named sub-modules selected by get_mod_details.
layers = get_mod_details(mod)

  init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
  init.normal(m.weight.data, 1.0, 0.02)
  init.constant(m.bias.data, 0.0)


In [4]:
# Display the collected layer names and modules.
layers

{'name': ['conv1.conv1',
  'conv1.relu1',
  'conv1.conv2',
  'conv1.relu2',
  'maxpool1',
  'conv2.conv1',
  'conv2.relu1',
  'conv2.conv2',
  'conv2.relu2',
  'maxpool2',
  'conv3.conv1',
  'conv3.relu1',
  'conv3.conv2',
  'conv3.relu2',
  'maxpool3',
  'conv4.conv1',
  'conv4.relu1',
  'conv4.conv2',
  'conv4.relu2',
  'maxpool4',
  'center.conv1',
  'center.relu1',
  'center.conv2',
  'center.relu2',
  'up_concat4.up',
  'up_concat4.pad',
  'up_concat4.conv.conv1',
  'up_concat4.conv.relu1',
  'up_concat4.conv.conv2',
  'up_concat4.conv.relu2',
  'up_concat3.up',
  'up_concat3.pad',
  'up_concat3.conv.conv1',
  'up_concat3.conv.relu1',
  'up_concat3.conv.conv2',
  'up_concat3.conv.relu2',
  'up_concat2.up',
  'up_concat2.pad',
  'up_concat2.conv.conv1',
  'up_concat2.conv.relu1',
  'up_concat2.conv.conv2',
  'up_concat2.conv.relu2',
  'up_concat1.up',
  'up_concat1.pad',
  'up_concat1.conv.conv1',
  'up_concat1.conv.relu1',
  'up_concat1.conv.conv2',
  'up_concat1.conv.relu2',
  'f

In [11]:
def get_activations_shapes(layers, x):
    """Run ``x`` through every layer in order and record each activation shape.

    Args:
        layers: dict with parallel lists ``'name'`` and ``'layer'`` (the
            structure returned by ``get_mod_details``).
        x: input tensor fed to the first layer.

    Returns:
        list of shapes: the input shape first, then the output shape of each
        layer in sequence.

    Note:
        Every layer is called with a single argument, so entries that need
        extra inputs (such as the ``'pad'`` entries that
        ``get_activations_shapes_as_dict`` calls with three arguments) are
        not supported here.
    """
    # The list was previously bound to a different name than the one appended
    # to and returned, which raised NameError on the first append.
    shapes = [x.shape]
    with torch.no_grad():
        for name, layer in zip(layers['name'], layers['layer']):
            print(name)
            x = layer(x)
            shapes.append(x.shape)

    return shapes

def get_activations_shapes_as_dict(layers, x):
    """Trace ``x`` through the layer list and map layer names to output shapes.

    Layers are dispatched on their name:

    * names containing ``'max'``: the tensor entering the layer is stashed
      before the layer is applied;
    * names containing ``'pad'`` (and not ``'max'``): the most recently
      unused stashed tensor is passed together with the current tensor and a
      padding list to the layer, and the two returned tensors are
      concatenated along dimension 1. No shape is recorded for this step;
    * anything else: the layer is applied to the current tensor.

    Args:
        layers: dict with parallel lists ``'name'`` and ``'layer'``.
        x: input tensor.

    Returns:
        dict whose first key is ``'input'`` (shape of ``x``) followed by one
        entry per non-pad layer, in execution order.
    """
    shapes = {'input': x.shape}
    skips = []
    skip_idx = -1  # walks backwards through the stashed tensors
    with torch.no_grad():
        for name, layer in zip(layers['name'], layers['layer']):
            print(name)
            if 'max' not in name and 'pad' in name:
                skip = skips[skip_idx]
                offset = x.size()[2] - skip.size()[2]
                half = offset // 2
                padding = [half, half, 0] * 2
                padded_skip, current = layer(skip, x, padding)
                x = torch.cat([padded_skip, current], 1)
                skip_idx -= 1
                continue
            if 'max' in name:
                # Keep the pre-pooling tensor for the matching pad step.
                skips.append(x)
            x = layer(x)
            shapes[name] = x.shape
    return shapes

In [12]:
# Random single-channel input of spatial size 256x256x99 (half of interp per axis).
x = torch.from_numpy(np.random.rand(1,1,256,256,99)).float()
# Per-layer activation shapes for this input; layer names are printed as they run.
acts = get_activations_shapes_as_dict(layers, x)

conv1.conv1
conv1.relu1
conv1.conv2
conv1.relu2
maxpool1
conv2.conv1
conv2.relu1
conv2.conv2
conv2.relu2
maxpool2
conv3.conv1
conv3.relu1
conv3.conv2
conv3.relu2
maxpool3
conv4.conv1
conv4.relu1
conv4.conv2
conv4.relu2
maxpool4
center.conv1
center.relu1
center.conv2
center.relu2
up_concat4.up
up_concat4.pad
up_concat4.conv.conv1
up_concat4.conv.relu1
up_concat4.conv.conv2
up_concat4.conv.relu2
up_concat3.up
up_concat3.pad
up_concat3.conv.conv1
up_concat3.conv.relu1
up_concat3.conv.conv2
up_concat3.conv.relu2
up_concat2.up
up_concat2.pad
up_concat2.conv.conv1
up_concat2.conv.relu1
up_concat2.conv.conv2
up_concat2.conv.relu2
up_concat1.up
up_concat1.pad
up_concat1.conv.conv1
up_concat1.conv.relu1
up_concat1.conv.conv2
up_concat1.conv.relu2
final
interpolation


In [13]:
def memory_cosumption(shapes, float_types='float'):
    """Sum the byte size of tensors with the given shapes, printing each one.

    Args:
        shapes: iterable of shapes (sequences of dimension sizes).
        float_types: one of ``'float'`` (4 bytes), ``'double'`` (8 bytes) or
            ``'half'`` (2 bytes) per element.

    Returns:
        Total number of bytes over all shapes (0 for an empty iterable).
    """
    bytes_per_element = {'float':4, 'double':8, 'half':2}
    total = 0
    for shape in shapes:
        nbytes = np.prod(shape)*bytes_per_element[float_types]
        total += nbytes
        print(shape, nbytes)
    return total

def forward_memory_cosumption_with_peak(shapes, float_types='float'):
    """Sum the byte size of every activation in a name -> shape mapping.

    Despite the name, no peak is tracked: the result is the plain sum over
    all entries of ``shapes``.

    Args:
        shapes: dict mapping a layer name to the shape of its activation.
        float_types: one of ``'float'`` (4 bytes), ``'double'`` (8 bytes) or
            ``'half'`` (2 bytes) per element.

    Returns:
        Total number of bytes over all shapes (0 for an empty mapping).

    Raises:
        KeyError: if ``float_types`` is not one of the supported names.
    """
    # Resolve the element size once, up front, so an unsupported type is
    # reported even when ``shapes`` is empty.
    bytes_per_element = {'float':4, 'double':8, 'half':2}[float_types]

    total = 0
    for shape in shapes.values():
        total += np.prod(shape)*bytes_per_element

    return total

In [14]:
def flt(t):
    """Return the number of bytes per element for the floating type name ``t``.

    Supported names are ``'float'`` (4), ``'double'`` (8) and ``'half'`` (2);
    any other name raises ``KeyError``.
    """
    bytes_per_element = {'float':4, 'double':8, 'half':2}
    return bytes_per_element[t]

def model_memory(mod, floatt = 'float'):
    """Return the bytes needed to store all parameters of ``mod``.

    Each parameter contributes its element count times the per-element size
    that ``flt`` reports for ``floatt``.
    """
    total = 0
    for param in mod.parameters():
        total += param.numel()*flt(floatt)
    return total

In [15]:
def convert_byte(v):
    """Scale a byte count to the largest decimal unit with a non-zero integer part.

    Units are tried in the order Bytes, KB, MB, GB using powers of 1000.
    Scanning stops at the first unit whose scaled value truncates to zero,
    and the previous unit is used; values too small for any unit stay in
    Bytes and values beyond GB are reported in GB.

    Args:
        v: number of bytes.

    Returns:
        ``(scaled_value, unit_name)`` tuple.
    """
    scale_by_unit = (('Bytes', 1), ('KB', 1e-3), ('MB', 1e-6), ('GB', 1e-9))
    best_unit, best_scale = scale_by_unit[0]
    for unit, scale in scale_by_unit:
        if int(v*scale) == 0:
            break
        best_unit, best_scale = unit, scale
    return v*best_scale, best_unit

In [17]:
# Total bytes of all recorded activations (float32 by default), shown in a readable unit.
cur_m = forward_memory_cosumption_with_peak(acts)
print(convert_byte(cur_m))

(3.5233136640000002, 'GB')


In [11]:
# Size of the model parameters in a readable unit.
convert_byte(model_memory(mod))

(154.464, 'KB')

In [12]:
def labels_mem(y,floatt = 'float'):
    """Return the bytes needed to hold ``y`` with elements of type ``floatt``.

    Only ``y.shape`` is read; the per-element size comes from ``flt``.
    """
    n_elements = np.prod(y.shape)
    return n_elements*flt(floatt)
    

In [13]:
# Dummy label tensor: 12 channels at 512x512x198, matching outsize and interp above.
# NOTE(review): labels_mem only reads y.shape, so materialising the full tensor may be avoidable.
y = torch.from_numpy(np.random.rand(1,12,512,512,198)).float()

In [14]:
# Size of the label tensor in a readable unit.
print(convert_byte(labels_mem(y)))

(2.491416576, 'GB')


In [22]:
# Small demo: compare bytes per element of a float tensor with those of its
# channel-wise argmax (dim 1) after casting to short.
a = torch.from_numpy(np.random.rand(1,2,5,5)).float()
print(a)
print(a.element_size())
print(torch.argmax(a,1))
print(torch.argmax(a,1).short().element_size())

tensor([[[[0.7194, 0.5986, 0.1094, 0.5060, 0.2073],
          [0.1055, 0.8037, 0.7823, 0.4362, 0.9706],
          [0.4120, 0.6327, 0.2057, 0.3950, 0.2494],
          [0.3431, 0.4298, 0.3011, 0.4617, 0.9267],
          [0.9172, 0.5996, 0.5468, 0.7262, 0.8042]],

         [[0.7219, 0.8398, 0.0785, 0.7518, 0.7601],
          [0.5366, 0.5807, 0.0706, 0.1895, 0.7764],
          [0.7727, 0.6994, 0.3189, 0.9119, 0.2093],
          [0.5298, 0.3085, 0.1669, 0.7628, 0.6441],
          [0.7853, 0.8312, 0.9386, 0.2036, 0.5099]]]])
4
tensor([[[1, 1, 0, 1, 1],
         [1, 0, 0, 0, 0],
         [1, 1, 1, 1, 0],
         [1, 0, 0, 1, 0],
         [0, 1, 1, 0, 0]]])
2


In [28]:
# Two lists of six values each; presumably per-run scores for two model
# variants (TODO confirm what they measure).
unet = [0.703,0.694,0.652,0.666,0.698,0.660]
rev = [0.698, 0.722, 0.697, 0.702, 0.706, 0.696]

# Mean and standard deviation of the first list only.
# NOTE(review): rev is defined but not summarised in this cell.
print(np.mean(unet))
print(np.std(unet))

0.6788333333333333
0.020086617988656547


In [9]:
# Size of x at 4 bytes per element, in a readable unit.
convert_byte(np.prod(x.shape)*4)

(103.809024, 'MB')

In [10]:
# Print CUDA peak and current allocated bytes before and after moving x to the GPU.
print((torch.cuda.max_memory_allocated()), ((torch.cuda.memory_allocated())))
x = x.cuda().detach()
print((torch.cuda.max_memory_allocated()), ((torch.cuda.memory_allocated())))

0 0
104857600 104857600


In [15]:
# Move the first two collected layers to the GPU and apply them to x,
# printing CUDA peak and current allocated bytes before and after.
L1 = layers['layer'][0].cuda()
L2 = layers['layer'][1].cuda()
print((torch.cuda.max_memory_allocated()), ((torch.cuda.memory_allocated())))
# NOTE(review): x is rebound to the layer output, so re-running this cell
# feeds the previous result back into L1 — confirm the saved error output
# below is not from such a re-run.
x = L1(x)
x = L2(x)
print((torch.cuda.max_memory_allocated()), ((torch.cuda.memory_allocated())))

430074880 312477184


RuntimeError: Expected weight to be a vector of size equal to the number of channels in input, but got weight of shape [1] and input of shape [4, 2, 256, 256, 99]

In [12]:
# Bytes of the float32 input (1x1x256x256x99) plus twice that amount;
# presumably the 2-channel output of the first conv, for comparison with
# the memory_allocated value printed above — TODO confirm.
103809024 + 207618048

311427072

In [13]:
# Inspect the first collected layer.
layers['layer'][0]

Conv3d(1, 2, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)