In [1]:
#export
import torch
import torch.nn as nn
import fastcore
from fastai.vision.all import *
from torchvision.models import resnet18
from torchvision.models import inception_v3
from torchvision.models import googlenet

In [2]:
#export
class AdaptiveConcatPool3d(nn.Module):
    "Layer that concatenates `nn.AdaptiveMaxPool3d` and `nn.AdaptiveAvgPool3d` along dim 1"
    def __init__(self, size=None):
        super().__init__()
        # A falsy `size` (the default `None` included) falls back to a 1x1x1 output
        if not size: size = (1, 1, 1)
        self.ap = nn.AdaptiveAvgPool3d(size)
        self.mp = nn.AdaptiveMaxPool3d(size)

    def forward(self, x):
        "Pool `x` both ways and stack the results: max-pooled first, average-pooled second"
        max_pooled = self.mp(x)
        avg_pooled = self.ap(x)
        return torch.cat((max_pooled, avg_pooled), dim=1)

In [3]:
#export
def get_hps(module):
    """Collect the hyper-parameters `module` was built with, inflated for a 3D layer.

    The names are the named arguments of `type(module).__init__` (without `self`); each
    value is read from the attribute of the same name on `module`. Arguments that are not
    stored as attributes (e.g. `device`/`dtype` on recent PyTorch layers) are skipped.
    Tuple values `(a, b)` are inflated to `((a+b)//2, a, b)`; any other value is returned
    unchanged.
    """
    init_code = type(module).__init__.__code__
    # `co_varnames` lists the arguments first and the function's local variables after
    # them, so slice it down to the named arguments and drop the leading `self`
    n_args = init_code.co_argcount + init_code.co_kwonlyargcount
    hp_names = init_code.co_varnames[1:n_args]

    # Creating hyper parameter dict and inflating tuple hps (kernel_size, padding, etc.)
    hps = {}
    for k in hp_names:
        # Not every init argument is kept on the instance; there is nothing to read then
        if not hasattr(module, k): continue
        v = getattr(module, k)
        hps[k] = ((v[0]+v[1])//2, *v) if isinstance(v, tuple) else v
    return hps

In [4]:
#export
@typedispatch
def inflate(c2d:nn.Conv2d):
    "Build an `nn.Conv3d` from `c2d`, repeating its 2D kernel along the new third dimension"
    hps = get_hps(c2d)
    # `get_hps` reads `c2d.bias` itself, while `nn.Conv3d` expects a boolean flag
    has_bias = hps['bias'] is not None
    hps['bias'] = has_bias

    c3d = nn.Conv3d(**hps)

    # Insert a size-1 axis at position 2 and broadcast it to the 3D kernel's shape
    # NOTE(review): the repeated kernel is not rescaled by the new depth - confirm this is intended
    inflated_weight = c2d.weight.unsqueeze(2).expand(*c3d.weight.shape)
    sd = {'weight': inflated_weight}
    if has_bias: sd['bias'] = c2d.bias

    c3d.load_state_dict(sd, strict=False)
    return c3d

@typedispatch
def inflate(bn2d:nn.BatchNorm2d):
    "Build an `nn.BatchNorm3d` with the hyper-parameters and state dict of `bn2d`"
    hps = get_hps(bn2d)
    bn3d = nn.BatchNorm3d(**hps)
    # The state dict of the 2D layer is loaded into the 3D layer as-is
    state = bn2d.state_dict()
    bn3d.load_state_dict(state)
    return bn3d

@typedispatch
def inflate(do2d:nn.Dropout2d):
    "Build an `nn.Dropout3d` with the same `p` and `inplace` as `do2d`"
    return nn.Dropout3d(p=do2d.p, inplace=do2d.inplace)

@typedispatch
def inflate(m:nn.MaxPool2d):
    "Build an `nn.MaxPool3d` from the hyper-parameters `get_hps` reads off `m`"
    hps = get_hps(m)
    return nn.MaxPool3d(**hps)
    
@typedispatch
def inflate(m:nn.AvgPool2d):
    "Build an `nn.AvgPool3d` from the hyper-parameters `get_hps` reads off `m`"
    hps = get_hps(m)
    return nn.AvgPool3d(**hps)

@typedispatch
def inflate(m:AdaptiveConcatPool2d):
    "Build an `AdaptiveConcatPool3d` from the hyper-parameters `get_hps` reads off `m`"
    hps = get_hps(m)
    return AdaptiveConcatPool3d(**hps)
    
@typedispatch
def inflate(m:nn.AdaptiveAvgPool2d):
    "Build an `nn.AdaptiveAvgPool3d` from the hyper-parameters `get_hps` reads off `m`"
    hps = get_hps(m)
    return nn.AdaptiveAvgPool3d(**hps)
    
@typedispatch
def inflate(m:nn.AdaptiveMaxPool2d):
    "Build an `nn.AdaptiveMaxPool3d` from the hyper-parameters `get_hps` reads off `m`"
    hps = get_hps(m)
    return nn.AdaptiveMaxPool3d(**hps)
    
@typedispatch
def inflate(m:nn.Module):
    "Fallback for any other module: inflate each direct child in place and return `m` itself"
    children = m.named_children()
    for child_name, submodule in children:
        # Dispatch picks the overload for the child's type; this fallback recurses into containers
        inflated = inflate(submodule)
        setattr(m, child_name, inflated)
    return m

# nbdev directive: presumably adds `inflate` to the exported module's `__all__` - confirm against nbdev docs
%nbdev_add2all inflate

In [5]:
# One 2D layer of each kind that has a dedicated `inflate` overload
do2d = nn.Dropout2d()
c2d = nn.Conv2d(3, 2, kernel_size=(3,3))
bn2d = nn.BatchNorm2d(5)

# Their inflated 3D counterparts
do3d, c3d, bn3d = map(inflate, (do2d, c2d, bn2d))

# Show every 2D layer next to its inflated version
do2d, do3d, c2d, c3d, bn2d, bn3d

(Dropout2d(p=0.5, inplace=False),
 Dropout3d(p=0.5, inplace=False),
 Conv2d(3, 2, kernel_size=(3, 3), stride=(1, 1)),
 Conv3d(3, 2, kernel_size=(3, 3, 3), stride=(1, 1, 1)),
 BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))

In [6]:
# Inflate every layer of a GoogLeNet, then switch the result to evaluation mode
m = inflate(googlenet())
m = m.eval()
m



GoogLeNet(
  (conv1): BasicConv2d(
    (conv): Conv3d(3, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2), padding=(3, 3, 3), bias=False)
    (bn): BatchNorm3d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool3d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (conv2): BasicConv2d(
    (conv): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
    (bn): BatchNorm3d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv3): BasicConv2d(
    (conv): Conv3d(64, 192, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (bn): BatchNorm3d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool2): MaxPool3d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (inception3a): Inception(
    (branch1): BasicConv2d(
      (conv): Conv3d(192, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
      (bn): BatchNorm3d(64, eps=0.001, mo

In [6]:
# Smoke test: push a random 5D batch through the inflated network
x = torch.randn(10,3,100,128,128)
# Only the output shape is inspected, so disable autograd to avoid keeping activations in memory
with torch.no_grad(): out = m(x)
out.shape

torch.Size([10, 1000])

In [6]:
# Regenerate the library modules from the notebooks; import only the name that is used
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 01_triplet_loss.ipynb.
Converted 02_inflator.ipynb.
Converted 03_video_block.ipynb.
Converted 04_datasets.ipynb.
Converted 05_first_inflated_NN.ipynb.
Converted index.ipynb.


In [7]:
#default_exp inflator