In [1]:
import os

# Run from the directory one level above this notebook, presumably so the local
# `scood` package imported below resolves. The notebook's own directory is
# captured only once, so re-running this cell does not keep climbing the tree
# (a bare os.chdir('..') moves up one more level on every execution).
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.dirname(NOTEBOOK_DIR))

In [2]:
# Automatically reload edited modules (e.g. the local scood package) before
# each cell executes, so source changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [3]:
import torch
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

In [19]:
from scood.networks.resnet18 import ResNet18 as ResNet18OG
# NOTE(review): this WideResNet binding is immediately shadowed by the
# `from scood.networks import ... WideResNet` on the next line; only that
# second import is the one used by the cells below.
from scood.networks.wrn import WideResNet
from scood.networks import ResNet18, WideResNet
from scood.utils import count_net_params
import torchvision

In [14]:
# Three ResNet18 variants to compare by parameter count: the one from
# scood.networks.resnet18 (aliased ResNet18OG), the scood.networks export, and
# torchvision's resnet18. Construction order consumes the global torch RNG for
# weight init, so keep it stable if initial weights matter.
og_resnet = ResNet18OG(dim_aux=100)
resnet = ResNet18(dim_aux=100)
tv_resnet = torchvision.models.resnet18()

In [15]:
# Parameter counts of the three ResNet18 variants, labelled so the numbers
# can be read without cross-referencing the construction cell.
for label, net in [
    ("ResNet18OG (scood.networks.resnet18)", og_resnet),
    ("ResNet18 (scood.networks)", resnet),
    ("resnet18 (torchvision)", tv_resnet),
]:
    print(f"{label}: {count_net_params(net)}")

11225262
11232942
11689512


In [31]:
# Which wide-ResNet architectures does timm register? (wildcard name filter)
timm.list_models(filter="*wide*")

['wide_resnet50_2', 'wide_resnet101_2']

In [32]:
# WideResNet configurations to compare: two scood models named WRN-28-10 and
# WRN-40-4 (first positional argument presumably the depth, matching the
# variable names) and timm's wide_resnet50_2. Construction order consumes the
# global torch RNG for weight init.
wrn_28_10 = WideResNet(28, widen_factor=10)
wrn_40_4 = WideResNet(40, widen_factor=4)
wrn_50 = timm.create_model('wide_resnet50_2')

In [33]:
# Parameter counts of the WideResNet variants, labelled so the numbers can be
# read without cross-referencing the construction cell.
for label, net in [
    ("WideResNet 28-10 (scood)", wrn_28_10),
    ("WideResNet 40-4 (scood)", wrn_40_4),
    ("wide_resnet50_2 (timm)", wrn_50),
]:
    print(f"{label}: {count_net_params(net)}")

36479194
8949210
68883240


In [29]:
# Module tree of the WRN-28-10 model (long repr).
wrn_28_10

WideResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (block1): NetworkBlock(
    (layer): Sequential(
      (0): BasicBlock(
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(16, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (convShortcut): Conv2d(16, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (1): BasicBlock(
        (bn1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNor

In [30]:
# Module tree of the WRN-40-4 model (long repr).
wrn_40_4

WideResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (block1): NetworkBlock(
    (layer): Sequential(
      (0): BasicBlock(
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (convShortcut): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (1): BasicBlock(
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, 

In [6]:
# Forward a random batch of two 3x32x32 images through WRN-28-10 and show the
# shapes of the two returned tensors (presumably logits and features).
# Uses `wrn_28_10`: the previous name `wideresnet` is not defined by any cell,
# so this failed on a fresh kernel; the recorded 640-wide features match this
# model. no_grad because this is only a shape probe.
with torch.no_grad():
    output = wrn_28_10(torch.randn(2, 3, 32, 32), return_feature=True)
[o.shape for o in output]

torch.Size([2, 640, 8, 8])


[torch.Size([2, 10]), torch.Size([2, 640])]

In [8]:
# NOTE(review): rebinds `resnet`, replacing the ResNet18(dim_aux=100) instance
# created earlier; the parameter count printed earlier refers to that model.
resnet = ResNet18()

In [11]:
# Forward a random batch of two 3x32x32 images through the scood ResNet18 and
# show the shapes of the two returned tensors (presumably logits and features).
# no_grad because this is only a shape probe; no autograd graph is needed.
with torch.no_grad():
    output = resnet(torch.randn(2, 3, 32, 32), return_feature=True)
[o.shape for o in output]

torch.Size([2, 512, 4, 4])


[torch.Size([2, 10]), torch.Size([2, 512])]

In [24]:
# Pretrained timm ResNet18 with output_stride=8 (32x32 inputs give a 4x4
# feature map below). pretrained=True may download weights on first run.
# .eval() is essential here: nn.Module instances start in training mode, and in
# training mode every forward pass updates the BatchNorm running statistics
# (track_running_stats=True, momentum=0.1). Without it, the random-noise probes
# below would silently overwrite the pretrained statistics.
model = timm.create_model('resnet18', output_stride=8, pretrained=True).eval()

In [25]:
# Input width of the classifier layer `fc` (the final feature dimension).
model.fc.in_features

512

In [14]:
# Backbone-only pass: forward_features returns a 4D spatial feature map (no
# classifier head). With output_stride=8, a 32x32 input yields 4x4.
# no_grad because this is only a shape probe.
with torch.no_grad():
    feature_map = model.forward_features(torch.randn(2, 3, 32, 32))
feature_map.shape

torch.Size([2, 512, 4, 4])

In [16]:
# Full forward pass (backbone + classifier head) while the original head is
# still attached, so the output is one score vector per image.
# no_grad because this is only a shape probe.
with torch.no_grad():
    logits = model(torch.randn(2, 3, 32, 32))
logits.shape

torch.Size([2, 1000])

In [26]:
# Remove the classifier head in place (0 output classes), so model(x) returns
# the feature vector instead of class scores.
# NOTE(review): this mutates `model`; outputs displayed by earlier cells (e.g.
# the [2, 1000] full forward pass) no longer reflect its current state.
model.reset_classifier(0)

In [28]:
# forward_features is unaffected by reset_classifier: still the 4D spatial
# feature map (4x4 for a 32x32 input at output_stride=8).
# no_grad because this is only a shape probe.
with torch.no_grad():
    feature_map = model.forward_features(torch.randn(2, 3, 32, 32))
feature_map.shape

torch.Size([2, 512, 4, 4])

In [27]:
# With the classifier removed by reset_classifier(0), the full forward pass
# returns a flat 512-d feature vector per image instead of class scores.
# no_grad because this is only a shape probe.
with torch.no_grad():
    embedding = model(torch.randn(2, 3, 32, 32))
embedding.shape

torch.Size([2, 512])

In [21]:
# Module tree of the timm ResNet18 (long repr). Note the 7x7 stride-2 stem conv
# followed by a stride-2 max-pool before layer1.
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  

In [22]:
# Resolve preprocessing settings (input size, interpolation, mean/std, crop
# fraction) from the model, with no user overrides (empty args dict).
# NOTE(review): this resolves to 224x224 inputs, while the probes above use
# 32x32 — confirm which resolution the downstream pipeline should use.
config = resolve_data_config({}, model=model)

In [23]:
# Display the resolved preprocessing config.
config

{'input_size': (3, 224, 224),
 'interpolation': 'bilinear',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'crop_pct': 0.875}

In [9]:
# Build the image preprocessing transform from the resolved config.
# NOTE(review): is_training is left at create_transform's default — confirm
# whether an evaluation or a training (augmenting) transform is intended.
transform = create_transform(**config)

