In [1]:
import torch
from models.revunet_3D import RevUnet3D, Interpolate
import numpy as np

In [2]:
def get_mod_details(model, layer_types=(torch.nn.Upsample, torch.nn.Conv3d, torch.nn.ReLU,
                                         torch.nn.GroupNorm, torch.nn.MaxPool3d)):
    """Collect the named sub-modules of `model` that are instances of `layer_types`.

    Args:
        model: a ``torch.nn.Module``.
        layer_types: class or tuple of classes to keep. The default reproduces the
            original selection: Upsample, Conv3d, ReLU, GroupNorm and MaxPool3d.

    Returns:
        dict with two parallel lists:
        ``'name'`` holds the dotted module names and ``'layer'`` holds the module objects.
        Both follow ``model.named_modules()`` order.
    """
    ret = {'name': [], 'layer': []}
    for name, layer in model.named_modules():
        # named_modules() also yields the root module under the empty name; skip it.
        if name and isinstance(layer, layer_types):
            ret['name'].append(name)
            ret['layer'].append(layer)
    return ret

In [3]:
# Model configuration for the memory study.
# Number of input channels (the probe input built later has a single channel).
inchan = 1
# Divisor that shrinks the standard U-Net widths: with 32, chans == [2, 4, 8, 16].
chanscale = 32
chans = [i//chanscale for i in [64, 128, 256, 512]]
# NOTE(review): presumably the number of output classes (the label tensor built
# later has 12 channels) — confirm against RevUnet3D.
outsize = 12
# NOTE(review): presumably the spatial size the output is interpolated to (it matches
# the (512, 512, 198) label volume built later) — confirm against RevUnet3D.
interp = (512,512,198)
mod = RevUnet3D(inchan, chans, outsize, interp)

# Upsample/Conv3d/ReLU/GroupNorm/MaxPool3d sub-modules, in named_modules() order.
layers = get_mod_details(mod)

In [4]:
def get_activations_shapes(layers, x, verbose=False):
    """Apply the non-reversible layers of `layers` in sequence and record output shapes.

    Args:
        layers: dict with parallel lists ``'name'`` and ``'layer'`` (as built by
            ``get_mod_details``). Each layer must be callable on a tensor.
        x: input tensor fed to the first layer.
        verbose: if True, print each layer name as it is visited.

    Returns:
        list of shapes: ``x.shape`` first, then the output shape of every layer
        whose name does not contain ``'reversible'``.
    """
    # Previously the list was created as `no_rev_shapes` but `shapes` was used,
    # which raised NameError on the first non-reversible layer.
    shapes = [x.shape]
    with torch.no_grad():
        for name, layer in zip(layers['name'], layers['layer']):
            if verbose:
                print(name)
            # Layers inside reversible blocks are skipped: they are not part of the
            # plain sequential chain this function follows.
            if 'reversible' not in name:
                x = layer(x)
                shapes.append(x.shape)
    return shapes

def get_activations_shapes_as_dict(layers, x):
    """Record the output shape of every layer in `layers`, keyed by layer name.

    Layers whose name lacks 'reversible' are applied in sequence; each output becomes
    the input of the next one.

    Layers whose name contains 'reversible' are evaluated residual-style on a channel
    split of the current tensor, ``layer(first_half) + second_half``, and that result's
    shape is recorded. The running tensor is not advanced by them.

    Returns:
        dict mapping ``'input'`` to ``x.shape``, then each layer name to its shape,
        in the order the layers were visited.
    """
    recorded = {'input': x.shape}
    with torch.no_grad():
        for layer_name, module in zip(layers['name'], layers['layer']):
            if 'reversible' in layer_name:
                half = x.shape[1] // 2
                first, second = x.split(half, dim=1)
                recorded[layer_name] = (module(first) + second).shape
            else:
                x = module(x)
                recorded[layer_name] = x.shape
    return recorded

In [5]:
# Random probe volume (float64 from NumPy, cast to float32 by .float()).
# Only the per-layer activation shapes are recorded, so the values themselves don't matter.
x = torch.from_numpy(np.random.rand(1,1,256,256,99)).float()
acts = get_activations_shapes_as_dict(layers, x)

  "See the documentation of nn.Upsample for details.".format(mode))


In [6]:
def memory_cosumption(shapes, float_types='float', verbose=True):
    """Total bytes needed to store one tensor for each shape in `shapes`.

    Args:
        shapes: iterable of shapes (tuples / ``torch.Size``), or a mapping whose values
            are shapes, such as the dict returned by ``get_activations_shapes_as_dict``.
        float_types: element type, one of ``'float'`` (4 B), ``'double'`` (8 B) or
            ``'half'`` (2 B).
        verbose: if True (the original behavior), print each shape with its byte size.

    Returns:
        total number of bytes.

    Raises:
        KeyError: if `float_types` is not one of the supported names.
    """
    from collections.abc import Mapping

    all_types = {'float': 4, 'double': 8, 'half': 2}
    bytes_per_elem = all_types[float_types]
    # Iterating a dict yields its keys (layer names), not shapes; use the values.
    if isinstance(shapes, Mapping):
        shapes = shapes.values()
    m = 0
    for s in shapes:
        nbytes = np.prod(s) * bytes_per_elem
        if verbose:
            print(s, nbytes)
        m += nbytes
    return m

def forward_memory_cosumption_with_peak(shapes, float_types='float'):
    """Estimate resident and peak activation memory (bytes) of a forward pass.

    Walks `shapes` in insertion order. This is the layout produced by
    ``get_activations_shapes_as_dict``: a mapping from layer name to output shape.

    * Names without 'reversible' are persistent activations. Their size is added to
      both the running total ``cur_m`` and the peak ``max_m``.
    * 'reversible' names containing ``encoders.<i>`` are transient. Sizes in the same
      encoder group ``i`` are summed on top of ``cur_m`` as it stood when the group
      started, and ``max_m`` is raised to that sum if it is larger. They are never
      added to ``cur_m``.
    * All other 'reversible' names are ignored.

    Args:
        shapes: mapping of layer name -> shape.
        float_types: ``'float'`` (4 B), ``'double'`` (8 B) or ``'half'`` (2 B).

    Returns:
        tuple ``(cur_m, max_m)`` in bytes.

    Raises:
        KeyError: for an unsupported `float_types`.
        ValueError: when a reversible 'encoders' name has no ``encoders.<digits>`` part.
    """
    import re

    bytes_per_elem = {'float': 4, 'double': 8, 'half': 2}[float_types]

    cur_m = 0
    max_m = 0
    enc = None       # encoder index of the current reversible group
    tmp_max_m = 0    # cur_m at group start + reversible activations seen in the group
    for k, s in shapes.items():
        nbytes = np.prod(s) * bytes_per_elem
        if 'reversible' not in k:
            cur_m += nbytes
            # NOTE(review): this adds onto max_m, so a transient encoder peak already
            # folded into max_m is carried forward. If reversible activations are
            # freed after their encoder, max(max_m, cur_m) would be the tighter
            # peak — confirm intent.
            max_m += nbytes
        elif 'encoders' in k:
            # Parse the full encoder index; the old k[9] read a single character,
            # so 'encoders.12' was grouped with encoder 1.
            match = re.search(r'encoders\.(\d+)', k)
            if match is None:
                raise ValueError(f"cannot find an encoder index in layer name {k!r}")
            tmp_enc = int(match.group(1))
            if tmp_enc == enc:
                tmp_max_m += nbytes
            else:
                enc = tmp_enc
                tmp_max_m = cur_m + nbytes
            max_m = max(max_m, tmp_max_m)
        # NOTE(review): reversible layers outside the encoders (e.g. decoders) are not
        # counted at all — confirm this is intended.
    return cur_m, max_m

In [9]:
def flt(t):
    """Bytes per element for the type name `t` ('float', 'double' or 'half').

    Raises KeyError for any other name.
    """
    bytes_per_element = {'float': 4, 'double': 8, 'half': 2}
    return bytes_per_element[t]

def model_memory(mod, floatt = 'float'):
    """Bytes occupied by all parameters of `mod` when stored as `floatt` elements."""
    total = 0
    for param in mod.parameters():
        # The width is looked up per parameter, so a model without parameters
        # returns 0 even for an unknown `floatt`.
        total += param.numel() * flt(floatt)
    return total

In [7]:
def convert_byte(v):
    """Express a byte count `v` in the largest decimal unit whose integer part is non-zero.

    Units are Bytes, KB (1e3), MB (1e6) and GB (1e9). Values of 1e9 and above stay in GB.

    Returns:
        tuple ``(scaled_value, unit_name)``.
    """
    scales = [('Bytes', 1), ('KB', 1e-3), ('MB', 1e-6), ('GB', 1e-9)]
    best_unit, best_scale = scales[0]
    for unit, scale in scales:
        if int(v * scale) == 0:
            break
        best_unit, best_scale = unit, scale
    return v * best_scale, best_unit

In [8]:
# Resident (cur_m) and peak (max_m) forward activation memory for the probe input.
# `memory_cosumption_with_peak` is not defined in this notebook; it only worked from
# stale kernel state. The peak estimator defined here is the forward_* variant.
cur_m, max_m = forward_memory_cosumption_with_peak(acts)
print(convert_byte(cur_m))
print(convert_byte(max_m))

(2.5949921280000003, 'GB')
(2.7507056640000003, 'GB')


In [11]:
# Parameter memory of the model, at 4 bytes per element (model_memory's default 'float').
convert_byte(model_memory(mod))

(154.464, 'KB')

In [12]:
def labels_mem(y,floatt = 'float'):
    """Bytes needed to store tensor `y` with `floatt` elements ('float', 'double', 'half').

    Only the element count of `y` is used, not its actual dtype.
    """
    # numel() is the tensor's exact element count as a Python int; np.prod of an
    # empty (0-dim) shape would return the float 1.0 instead.
    return y.numel()*flt(floatt)
    

In [13]:
# Random stand-in for a 12-channel label volume. Only its shape is used (labels_mem),
# so draw float32 directly. The old np.random.rand path first materialised a float64
# array twice the size of y (~5 GB) before casting it down.
y = torch.rand(1, 12, 512, 512, 198, dtype=torch.float32)

In [14]:
# Label memory at 4 bytes per element (labels_mem's default 'float').
print(convert_byte(labels_mem(y)))

(2.491416576, 'GB')


In [22]:
# Scratch check: element size of a float32 tensor vs. its channel argmax cast to int16.
# NOTE(review): presumably gauging the cost of storing labels as class indices
# (2 bytes via .short()) instead of per-class float maps (4 bytes) — confirm.
a = torch.from_numpy(np.random.rand(1,2,5,5)).float()
print(a)
print(a.element_size())
# argmax over dim 1 drops the channel axis: (1, 2, 5, 5) -> (1, 5, 5).
print(torch.argmax(a,1))
print(torch.argmax(a,1).short().element_size())

tensor([[[[0.7194, 0.5986, 0.1094, 0.5060, 0.2073],
          [0.1055, 0.8037, 0.7823, 0.4362, 0.9706],
          [0.4120, 0.6327, 0.2057, 0.3950, 0.2494],
          [0.3431, 0.4298, 0.3011, 0.4617, 0.9267],
          [0.9172, 0.5996, 0.5468, 0.7262, 0.8042]],

         [[0.7219, 0.8398, 0.0785, 0.7518, 0.7601],
          [0.5366, 0.5807, 0.0706, 0.1895, 0.7764],
          [0.7727, 0.6994, 0.3189, 0.9119, 0.2093],
          [0.5298, 0.3085, 0.1669, 0.7628, 0.6441],
          [0.7853, 0.8312, 0.9386, 0.2036, 0.5099]]]])
4
tensor([[[1, 1, 0, 1, 1],
         [1, 0, 0, 0, 0],
         [1, 1, 1, 1, 0],
         [1, 0, 0, 1, 0],
         [0, 1, 1, 0, 0]]])
2


In [28]:
# Six scores per model. NOTE(review): presumably per-run evaluation metrics for a
# plain U-Net vs. the reversible U-Net — confirm the metric and source.
unet = [0.703,0.694,0.652,0.666,0.698,0.660]
rev = [0.698, 0.722, 0.697, 0.702, 0.706, 0.696]

# NOTE(review): only `unet` is summarised; `rev` is never used. np.std defaults to
# ddof=0 (population std); pass ddof=1 if the sample std over runs is intended.
print(np.mean(unet))
print(np.std(unet))

0.6788333333333333
0.020086617988656547


In [9]:
# Bytes of the probe input: element count × bytes per element, taken from the tensor
# itself instead of a hard-coded 4 (x is float32, so element_size() is 4).
# NOTE: the saved output (103809024 B) is 4× the (1, 1, 256, 256, 99) input built above,
# presumably from an earlier batch-4 x left in the kernel; re-run to refresh it.
convert_byte(x.numel() * x.element_size())

(103.809024, 'MB')

In [10]:
# CUDA allocator usage (peak, current) before and after moving the probe input to the GPU.
print((torch.cuda.max_memory_allocated()), ((torch.cuda.memory_allocated())))
# Rebinds x to the GPU copy: later cells that read x now see a CUDA tensor.
# .detach() has no effect here, since x was built from NumPy and does not require grad.
x = x.cuda().detach()
print((torch.cuda.max_memory_allocated()), ((torch.cuda.memory_allocated())))

0 0
104857600 104857600


In [15]:
# Apply the first two collected layers on the GPU and compare allocator usage before/after.
# NOTE(review): the saved output ends in a RuntimeError. A norm layer whose weight has
# shape [1] received a 2-channel input, presumably because layers['layer'][1] is not the
# layer that follows the Conv3d(1, 2, ...) at index 0 in the forward pass
# (named_modules() order need not be execution order) — confirm.
L1 = layers['layer'][0].cuda()
L2 = layers['layer'][1].cuda()
print((torch.cuda.max_memory_allocated()), ((torch.cuda.memory_allocated())))
x = L1(x)
x = L2(x)
print((torch.cuda.max_memory_allocated()), ((torch.cuda.memory_allocated())))

430074880 312477184


RuntimeError: Expected weight to be a vector of size equal to the number of channels in input, but got weight of shape [1] and input of shape [4, 2, 256, 256, 99]

In [12]:
# Scratch arithmetic: input bytes (103809024, as reported by the input-size cell) plus
# the output of the first Conv3d, which has 2 output channels (2 × 103809024 = 207618048 B).
103809024 + 207618048

311427072

In [13]:
# Inspect the first collected layer (its repr shows in/out channels and kernel config).
layers['layer'][0]

Conv3d(1, 2, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)