In [None]:
# default_exp core

# Core functionality

> Contains module definitions and helper functions for building and modifying PyTorch networks

In [None]:
#hide
from nbdev.showdoc import *
import torch
import torch.nn as nn
from functools import partial
import torchvision.models as models
from fastai.vision import *

# Utility functions

In [None]:
#export
def print_all(*args, **kwargs):
    "Print each positional argument, then each keyword argument as `key : value`, one per line"
    for value in args:
        print(value)
    for name, value in kwargs.items():
        print(f'{name} : {value}')

In [None]:
# Positional arguments are printed first, then each keyword argument as `key : value`
print_all('one', 'two', arg1='three', arg2='four')

one
two
arg1 : three
arg2 : four


In [None]:
#export
def children(m:nn.Module):
    "Return the direct (non-recursive) child modules of `m` as a list"
    kids = [*m.children()]
    return kids

In [None]:
#export
def recursively_apply_to_children(m, f):
    "Call `f` on `m` and then on every descendant module, depth-first, parents before children"
    pending = [m]
    while pending:
        node = pending.pop()
        # Children are captured *before* `f` runs on `node`, so if `f` swaps out
        # submodules, the originally attached children are the ones visited.
        snapshot = list(node.children())
        if isinstance(node, nn.Module):
            f(node)
        # Push in reverse so the first child is popped (visited) first.
        pending.extend(reversed(snapshot))

In [None]:
#export
def bn_to_tanh(m):
    "Return a `TanHNorm` with `m.num_features` channels if `m` is a `BatchNorm2d`, otherwise return `m` unchanged"
    if not isinstance(m, nn.BatchNorm2d):
        return m
    new = TanHNorm(m.num_features)
    # Put the replacement on the same device/dtype as the layer it replaces; otherwise
    # converting a model that already lives on e.g. a GPU yields a device mismatch in forward.
    ref = next((t for t in [*m.parameters(), *m.buffers()] if t.is_floating_point()), None)
    if ref is not None:
        new = new.to(device=ref.device, dtype=ref.dtype)
    return new

# NN Layers / Modules

In [None]:
#hide
class AdjustModule(nn.Module):
    """Placeholder (TODO): intended to add multiplicative and additive adjustments
    to any nn.Module passed in. Not implemented yet -- constructing it always raises."""
    def __init__(self):
        message = "This module has not yet been implemented"
        raise NotImplementedError(message)

In [None]:
#export
class AdjustNormFunc(nn.Module):
    """Creates a BatchNorm-like module using func : x = func(x) * scale + shift

    `nf` is the channel count. `scale` and `shift` are learnable per-channel parameters of
    shape (nf, 1, 1), initialised to ones and zeros. `name` is the label shown by str/repr;
    when it is not given, the `__name__` of `func` is used (falling back to `repr(func)`).
    """
    def __init__(self, nf, func=torch.tanh, name=None):
        super().__init__()
        self.func = func
        self.name = name
        self.nf = nf
        # (nf, 1, 1) broadcasts against inputs whose last three dims are (nf, H, W);
        # presumably (N, nf, H, W) feature maps -- TODO confirm for other layouts.
        self.scale = nn.Parameter(torch.ones (nf, 1, 1))
        self.shift = nn.Parameter(torch.zeros(nf, 1, 1))

    def forward(self, x):
        return self.func(x) * self.scale + self.shift

    def __str__(self):
        # Prefer the explicit name; otherwise use the function's name rather than its
        # default repr, which embeds an unstable memory address for builtins like torch.tanh.
        label = self.name or getattr(self.func, '__name__', repr(self.func))
        return f"Adjusted {label}({self.nf})"

    def __repr__(self):
        return self.__str__()

In [None]:
#export
def tanSigmoid(x:torch.Tensor)->torch.Tensor:
    "Sigmoid stretched from (0, 1) onto (-1, 1), i.e. 2*sigmoid(x) - 1"
    squashed = torch.sigmoid(x)
    return 2 * squashed - 1

In [None]:
#export
def gaussian(x:torch.Tensor)->torch.Tensor:
    "Unnormalized Gaussian bump exp(-x**2): symmetric about 0 and bounded in (0, 1]"
    # Squaring is what makes this a Gaussian; exp(-x) alone is an unbounded exponential
    # that overflows to inf for large negative inputs.
    return torch.exp(-x.pow(2))

In [None]:
#export
# Presets of AdjustNormFunc with the squashing function and display name already
# bound, so each one is constructed from the channel count alone, e.g. TanHNorm(64).
TanHNorm       = partial(AdjustNormFunc, name='TanH',       func=torch.tanh)
TanSigmoidNorm = partial(AdjustNormFunc, name='TanSigmoid', func=tanSigmoid)
GaussianNorm   = partial(AdjustNormFunc, name='Gaussian',   func=gaussian)

In [None]:
# Instantiate each preset with a different channel count
a = TanHNorm(3)
b = TanSigmoidNorm(5)
c = GaussianNorm(6)

In [None]:
# AdjustNormFunc's custom __str__ shows each preset's name and channel count
print_all(a, b, c)

Adjusted TanH(3)
Adjusted TanSigmoid(5)
Adjusted Gaussian(6)


# Modifying Existing Modules

In [None]:
#export
def recursive_getattr(obj:nn.Module, name:str):
    """ getattr for nested attributes with `.` in their names, e.g. 'layer1.0.bn1'.

    An empty `name` returns `obj` itself.
    """
    # ''.split('.') is [''] (never an empty list), so the empty path must be
    # detected on the string itself rather than on the split result.
    if not name:
        return obj
    for attr in name.split('.'):
        obj = getattr(obj, attr)
    return obj

def recursive_setattr(obj:nn.Module, name:str, new_attr):
    """ setattr for nested attributes with `.` in their names, e.g. 'layer1.0.bn1'.

    Walks every component but the last with getattr, then sets the last one to `new_attr`.
    Raises ValueError for an empty `name`: `obj` itself cannot be replaced in place.
    """
    if not name:
        # Without this guard setattr(obj, '', ...) would silently create an attribute
        # (on an nn.Module: a child module) literally named ''.
        raise ValueError("recursive_setattr needs a non-empty attribute path; the root object cannot be replaced in place")
    *parents, last = name.split('.')
    for attr in parents:
        obj = getattr(obj, attr)
    setattr(obj, last, new_attr)

def recreate_network(m:nn.Module, replace_func:Callable, condition:Callable=None)->nn.Module:
    """ modifies `m` by replacing each module that satisfies `condition` with replace_func(module)

    `condition` defaults to "replace_func returns something other than its input"; in that
    case replace_func is called exactly once per module. The root module cannot be swapped
    in place, so if it qualifies, its replacement is returned instead of `m` (this is also
    how a bare module without children is handled). Otherwise each qualifying submodule is
    replaced in place and `m` is returned.
    """
    def _replacement(module):
        # The replacement for `module`, or None when it should be kept as-is.
        if condition is None:
            new = replace_func(module)
            return None if module == new else new
        return replace_func(module) if condition(module) else None

    # Snapshot before mutating anything; named_modules() lists the root first.
    named = list(m.named_modules())
    new_root = _replacement(m)
    if new_root is not None:
        return new_root
    for name, module in named[1:]:
        new = _replacement(module)
        if new is not None:
            recursive_setattr(m, name, new)
    return m

In [None]:
# Swap every BatchNorm2d in an untrained ResNet-18 for a TanHNorm with the same channel count.
# `m` is modified in place and also returned, so the cell displays the converted architecture.
m = models.resnet18(pretrained=False)
recreate_network(m, bn_to_tanh)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): Adjusted TanH(64)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): Adjusted TanH(64)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): Adjusted TanH(64)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): Adjusted TanH(64)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): Adjusted TanH(64)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias

In [None]:
# Sanity check: the converted ResNet `m` from the cell above still runs a forward
# pass on a batch of 16 random 3-channel 224x224 inputs
x = torch.randn(16, 3, 224, 224)
m(x).shape

torch.Size([16, 1000])

In [None]:
# A bare BatchNorm2d has no children, so recreate_network returns the replacement
# instead of modifying anything in place -- `m` itself is still the BatchNorm2d here
m = nn.BatchNorm2d(19)
recreate_network(m, bn_to_tanh)

Adjusted TanH(19)

In [None]:
# Same conversion on a small CNN; `simple_cnn` presumably comes from the
# `from fastai.vision import *` star import in the setup cell
m = simple_cnn([3, 8, 16, 64, 128, 16, 4], bn=True)
recreate_network(m, bn_to_tanh)

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): ReLU(inplace)
    (2): Adjusted TanH(8)
  )
  (1): Sequential(
    (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): ReLU(inplace)
    (2): Adjusted TanH(16)
  )
  (2): Sequential(
    (0): Conv2d(16, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): ReLU(inplace)
    (2): Adjusted TanH(64)
  )
  (3): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): ReLU(inplace)
    (2): Adjusted TanH(128)
  )
  (4): Sequential(
    (0): Conv2d(128, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): ReLU(inplace)
    (2): Adjusted TanH(16)
  )
  (5): Sequential(
    (0): Conv2d(16, 4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace)
  )
  (6): Sequential(
    (0): AdaptiveAvgPool2d(output_size=1)
    (1):

In [None]:
# Forward pass through the converted CNN `m` from the previous cell
x = torch.randn(16, 3, 30, 30)
m(x).shape

torch.Size([16, 4])