In [1]:
# Enable IPython's autoreload so edits to project modules (e.g. `nn.models...`)
# are picked up before each cell runs, without restarting the kernel.
%load_ext autoreload

%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## ResNet

In [2]:
import torch
from nn.models.classification.resnet import ResNet
from torch import nn
from nn.models.classification.senet import SENetBasicBlock

### Create a default model

ResNets of different depths can be created with the classmethods on `ResNet` (`resnet18`, `resnet34`, `resnet50`, `resnet101`, `resnet152`).

In [12]:
# Each classmethod on `ResNet` builds a network of the corresponding depth.
# Displaying a model prints its (very long) repr, and only a cell's last
# expression is shown anyway, so report each variant's parameter count instead.
resnet_factories = {
    "resnet18": ResNet.resnet18,
    "resnet34": ResNet.resnet34,
    "resnet50": ResNet.resnet50,
    "resnet101": ResNet.resnet101,
    "resnet152": ResNet.resnet152,
}
for variant_name, build_variant in resnet_factories.items():
    n_params = sum(p.numel() for p in build_variant().parameters())
    print(f"{variant_name}: {n_params:,} parameters")

KeyboardInterrupt: 

### Customization

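A ResNet can be customized in several ways: change the activation function, the number of classes, the block type, or the initial convolution. You can also collect the output of each encoder stage.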

In [16]:
# Change the activation function.
resnet_selu = ResNet.resnet18(activation=nn.SELU)
# Change the number of output classes (default is 1000).
resnet_100_classes = ResNet.resnet18(n_classes=100)
# Use a different block type, e.g. `SENetBasicBlock` (presumably a
# squeeze-and-excitation variant of the basic block).
resnet_se = ResNet.resnet18(block=SENetBasicBlock)
# Replace the initial (stem) convolution. Kept under its own name so the
# customized model is not overwritten below. Note: with no `stride`/`padding`
# this 3x3 conv only trims 2 pixels from H and W, so downstream feature sizes
# may differ from the default stem's.
resnet_custom_stem = ResNet.resnet18()
resnet_custom_stem.encoder.gate.conv1 = nn.Conv2d(3, 64, kernel_size=3)

# Collect the output of each encoder stage.
# `eval()` puts mode-dependent layers (e.g. BatchNorm/Dropout, if present) into
# inference mode, and `no_grad()` skips building an autograd graph we never use.
x = torch.rand((1, 3, 224, 224))
model = ResNet.resnet18()
model.eval()
features = []
with torch.no_grad():
    x = model.encoder.gate(x)
    for block in model.encoder.blocks:
        x = block(x)
        features.append(x)

print([feature.shape for feature in features])
# [torch.Size([1, 64, 56, 56]), torch.Size([1, 128, 28, 28]), torch.Size([1, 256, 14, 14]), torch.Size([1, 512, 7, 7])]

[torch.Size([1, 64, 56, 56]), torch.Size([1, 128, 28, 28]), torch.Size([1, 256, 14, 14]), torch.Size([1, 512, 7, 7])]


## DenseNet

In [9]:
from nn.models.classification.densenet import DenseNet

# Change the activation function.
densenet_selu = DenseNet.densenet121(activation=nn.SELU)
# Change the number of output classes (default is 1000).
densenet_100_classes = DenseNet.densenet121(n_classes=100)
# A different block can presumably be passed via `block=`, as with ResNet.
# TODO: add a concrete example once a compatible DenseNet block is available.
# Replace the initial (stem) convolution. Kept under its own name so the
# customized model is not overwritten below.
densenet_custom_stem = DenseNet.densenet121()
densenet_custom_stem.encoder.gate.conv1 = nn.Conv2d(3, 64, kernel_size=3)

# Collect the output of each encoder stage, in inference mode and without
# building an autograd graph.
x = torch.rand((1, 3, 224, 224))
model = DenseNet.densenet121()
model.eval()
features = []
with torch.no_grad():
    x = model.encoder.gate(x)
    for block in model.encoder.blocks:
        x = block(x)
        features.append(x)

print([feature.shape for feature in features])
# [torch.Size([1, 128, 28, 28]), torch.Size([1, 256, 14, 14]), torch.Size([1, 512, 7, 7]), torch.Size([1, 1024, 7, 7])]


[torch.Size([1, 128, 28, 28]), torch.Size([1, 256, 14, 14]), torch.Size([1, 512, 7, 7]), torch.Size([1, 1024, 7, 7])]
