In [13]:

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np

In [14]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np

class SizeEstimator(object):
    '''
    Estimates the size of PyTorch models in memory for a given input size.

    The estimate is the sum of three terms, each counted in elements and
    multiplied by `bits`:
      * every parameter of the model,
      * the output of every leaf module during one forward pass, doubled
        to account for the backward pass,
      * the input tensor itself.

    Tensors produced by functional calls inside `forward` (i.e. outside any
    submodule) are invisible to the estimator and are not counted.
    '''

    def __init__(self, model, input_size=(1,1,32,32), bits=32):
        '''
        Args:
            model: the `nn.Module` to measure.
            input_size: full shape of the sample input fed to `model`.
            bits: number of bits used to store one element.
        '''
        self.model = model
        self.input_size = input_size
        # Honour the caller's precision instead of always assuming 32 bits.
        self.bits = bits

    def get_parameter_sizes(self):
        '''Get sizes of all parameters in `model` (stored in `self.param_sizes`)'''
        # `parameters()` already recurses through nested submodules and yields
        # each parameter exactly once, so no per-module loop is needed.
        self.param_sizes = [np.array(p.size()) for p in self.model.parameters()]

    def get_output_sizes(self):
        '''Run one sample input through `model` and record every leaf output size'''
        out_sizes = []

        def record(output):
            # A module may return a tensor or a (nested) tuple/list of tensors.
            if torch.is_tensor(output):
                out_sizes.append(np.array(output.size()))
            elif isinstance(output, (tuple, list)):
                for item in output:
                    record(item)

        def hook(module, inputs, output):
            record(output)

        # Only leaf modules are hooked: containers such as nn.Sequential just
        # forward their children's outputs and would be double counted.
        leaves = [m for m in self.model.modules() if len(list(m.children())) == 0]
        handles = [m.register_forward_hook(hook) for m in leaves]

        # Place the sample input next to the model's parameters, if it has any.
        first_param = next(self.model.parameters(), None)
        device = first_param.device if first_param is not None else None
        input_ = torch.zeros(*self.input_size, device=device)

        try:
            # A single real forward pass respects the model's own wiring
            # (nesting, reshapes, skip connections); no autograd graph is needed.
            with torch.no_grad():
                self.model(input_)
        finally:
            # Always detach the hooks so the model is left untouched.
            for handle in handles:
                handle.remove()

        self.out_sizes = out_sizes

    def calc_param_bits(self):
        '''Calculate total number of bits to store `model` parameters'''
        self.param_bits = sum(int(np.prod(s)) * self.bits for s in self.param_sizes)

    def calc_forward_backward_bits(self):
        '''Calculate bits to store forward and backward pass'''
        total_bits = sum(int(np.prod(s)) * self.bits for s in self.out_sizes)
        # multiply by 2 for both forward AND backward
        self.forward_backward_bits = total_bits * 2

    def calc_input_bits(self):
        '''Calculate bits to store input'''
        self.input_bits = int(np.prod(np.array(self.input_size))) * self.bits

    def estimate_size(self):
        '''
        Estimate model size in memory in megabytes and bits.

        Returns:
            (total_megabytes, total_bits)
        '''
        self.get_parameter_sizes()
        self.get_output_sizes()
        self.calc_param_bits()
        self.calc_forward_backward_bits()
        self.calc_input_bits()
        total = self.param_bits + self.forward_backward_bits + self.input_bits

        # bits -> bytes -> mebibytes
        total_megabytes = (total/8)/(1024**2)
        return total_megabytes, total

# Teacher

In [15]:
class Teacher(nn.Module):
    """MLP mapping `n_features` inputs to `n_classes` outputs.

    Hidden layers have 32, 16 and 8 units, each followed by a ReLU.
    """

    def __init__(self, n_features, n_classes):
        super().__init__()

        # Linear -> ReLU for every consecutive pair of widths, then the
        # final projection to `n_classes` (no activation after it).
        widths = (n_features, 32, 16, 8)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], n_classes))

        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to `x` and return its raw outputs."""
        out = self.mlp(x)
        return out

In [16]:
# Teacher network: 43 input features, 2 output classes; the bare name displays it.
model = Teacher(n_features=43, n_classes=2)
model

Teacher(
  (mlp): Sequential(
    (0): Linear(in_features=43, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=16, bias=True)
    (3): ReLU()
    (4): Linear(in_features=16, out_features=8, bias=True)
    (5): ReLU()
    (6): Linear(in_features=8, out_features=2, bias=True)
  )
)

In [17]:
# Total number of scalar parameters held by `model`.
sum(weight.numel() for weight in model.parameters())

2090

In [25]:
# Bytes taken by the parameters and by the registered buffers of `model`.
mem_params = sum(p.numel() * p.element_size() for p in model.parameters())
mem_bufs = sum(b.numel() * b.element_size() for b in model.buffers())
mem = mem_params + mem_bufs  # in bytes
mem

8360

# Student 1

In [28]:
class Student1(nn.Module):
    """MLP mapping `n_features` inputs to `n_classes` outputs.

    Hidden layers have 16, 8 and 4 units, each followed by a ReLU.
    """

    def __init__(self, n_features, n_classes):
        super().__init__()

        # Hidden stack first (Linear -> ReLU per width), output layer last.
        hidden_widths = (16, 8, 4)
        stack = []
        fan_in = n_features
        for fan_out in hidden_widths:
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
            fan_in = fan_out
        stack.append(nn.Linear(fan_in, n_classes))

        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the MLP to `x` and return its raw outputs."""
        out = self.mlp(x)
        return out

In [29]:
# First student network: 43 input features, 2 output classes; the bare name displays it.
model = Student1(n_features=43, n_classes=2)
model

Student1(
  (mlp): Sequential(
    (0): Linear(in_features=43, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=8, bias=True)
    (3): ReLU()
    (4): Linear(in_features=8, out_features=4, bias=True)
    (5): ReLU()
    (6): Linear(in_features=4, out_features=2, bias=True)
  )
)

In [30]:
# Total number of scalar parameters held by `model`.
sum(weight.numel() for weight in model.parameters())

886

In [32]:
# Bytes taken by the parameters and by the registered buffers of `model`.
mem_params = sum(p.numel() * p.element_size() for p in model.parameters())
mem_bufs = sum(b.numel() * b.element_size() for b in model.buffers())
mem = mem_params + mem_bufs  # in bytes
mem

3544

# Student 2

In [33]:
class Student2(nn.Module):
    """MLP mapping `n_features` inputs to `n_classes` outputs.

    A single hidden layer of 16 units followed by a ReLU.
    """

    def __init__(self, n_features, n_classes):
        super().__init__()

        hidden = nn.Linear(n_features, 16)
        head = nn.Linear(16, n_classes)
        self.mlp = nn.Sequential(hidden, nn.ReLU(), head)

    def forward(self, x):
        """Apply the MLP to `x` and return its raw outputs."""
        out = self.mlp(x)
        return out

In [35]:
# Second student network: 43 input features, 2 output classes; the bare name displays it.
model = Student2(n_features=43, n_classes=2)
model

Student2(
  (mlp): Sequential(
    (0): Linear(in_features=43, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=2, bias=True)
  )
)

In [36]:
# Total number of scalar parameters held by `model`.
sum(weight.numel() for weight in model.parameters())

738

In [37]:
# Bytes taken by the parameters and by the registered buffers of `model`.
mem_params = sum(p.numel() * p.element_size() for p in model.parameters())
mem_bufs = sum(b.numel() * b.element_size() for b in model.buffers())
mem = mem_params + mem_bufs  # in bytes
mem

2952

# Student 3

In [40]:
class Student3(nn.Module):
    """MLP mapping `n_features` inputs to `n_classes` outputs.

    A single hidden layer of 8 units followed by a ReLU.
    """

    def __init__(self, n_features, n_classes):
        super().__init__()

        hidden = nn.Linear(n_features, 8)
        head = nn.Linear(8, n_classes)
        self.mlp = nn.Sequential(hidden, nn.ReLU(), head)

    def forward(self, x):
        """Apply the MLP to `x` and return its raw outputs."""
        out = self.mlp(x)
        return out

In [41]:
# Third student network: 43 input features, 2 output classes; the bare name displays it.
model = Student3(n_features=43, n_classes=2)
model

Student3(
  (mlp): Sequential(
    (0): Linear(in_features=43, out_features=8, bias=True)
    (1): ReLU()
    (2): Linear(in_features=8, out_features=2, bias=True)
  )
)

In [42]:
# Total number of scalar parameters held by `model`.
sum(weight.numel() for weight in model.parameters())

370

In [43]:
# Bytes taken by the parameters and by the registered buffers of `model`.
mem_params = sum(p.numel() * p.element_size() for p in model.parameters())
mem_bufs = sum(b.numel() * b.element_size() for b in model.buffers())
mem = mem_params + mem_bufs  # in bytes
mem

1480