In [2]:
from urllib.request import urlopen

import timm
import torch
from PIL import Image
from timm.layers import get_norm_act_layer
from timm.models.densenet import DenseBlock, DenseTransition

# Sample image used throughout the notebook.
IMAGE_URL = (
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
)

# PIL decodes lazily, so force the decode (via convert) while the HTTP response
# is still open; the context manager then closes the connection instead of
# leaving it dangling for the lifetime of the kernel.
with urlopen(IMAGE_URL) as response:
    img = Image.open(response).convert('RGB')

model = timm.create_model(
    'densenet201.tv_in1k',
    pretrained=True,
    num_classes=0,  # remove classifier nn.Linear
)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

# Preprocess once and reuse the same batch for both code paths below.
img_batch = transforms(img).unsqueeze(0)

# Pure inference: disable autograd so no graph is recorded or kept alive by `output`.
with torch.no_grad():
    output = model(img_batch)  # output is (batch_size, num_features) shaped tensor

    # or equivalently (without needing to set num_classes=0)

    output = model.forward_features(img_batch)
    # output is unpooled, a (1, 1920, 7, 7) shaped tensor

    output = model.forward_head(output, pre_logits=True)
    # output is a (1, num_features) shaped tensor

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Show the raw image size followed by the pooled feature shape, one per line.
for shape_info in (img.size, output.shape):
    print(shape_info)

(512, 512)
torch.Size([1, 1920])


In [7]:
# Preprocess the image into a single-item batch.
img_input = transforms(img).unsqueeze(0)
print(img_input.shape)

# Feature extraction only: run without autograd so `img_output` does not keep
# an autograd graph (and its intermediate activations) alive in kernel memory.
with torch.no_grad():
    img_output = model.forward_features(img_input)
print(img_output.shape)

torch.Size([1, 3, 224, 224])
torch.Size([1, 1920, 7, 7])


In [8]:
import torch
import torch.nn as nn
from collections import OrderedDict

class DenseNetWithoutPooling(nn.Module):
    """DenseNet-style backbone that returns the spatial feature map.

    Differences visible in this implementation compared with a classifier:
    the stem pooling layer is an ``nn.Identity`` when no ``aa_layer`` is given,
    and both the global pool and the classifier are ``nn.Identity``, so
    ``forward`` returns the output of ``self.features`` (after dropout).

    Requires ``get_norm_act_layer``, ``DenseBlock`` and ``DenseTransition`` to
    be available in the enclosing namespace.

    Weights are freshly initialised in ``__init__``; nothing is loaded here.
    """

    def __init__(
            self,
            growth_rate=32,
            block_config=(6, 12, 24, 16),
            num_classes=1000,
            in_chans=3,
            global_pool=None,  # Set to None to remove global pooling
            bn_size=4,
            stem_type='',
            act_layer='relu',
            norm_layer='batchnorm2d',
            aa_layer=None,
            drop_rate=0.,
            proj_drop_rate=0.,
            memory_efficient=False,
            aa_stem_only=True,
    ):
        """Build the stem, dense blocks, transitions and final norm.

        Args:
            growth_rate: Channels added by each dense layer; the stem produces
                ``2 * growth_rate`` channels.
            block_config: Number of layers in each dense block, in order.
            num_classes: Stored on ``self.num_classes`` only; no classifier is built.
            in_chans: Number of input channels of the first convolution.
            global_pool: Accepted for signature compatibility but ignored;
                ``self.global_pool`` is always ``nn.Identity``.
            bn_size: Forwarded to every ``DenseBlock``.
            stem_type: If it contains ``'deep'`` a three-conv 3x3 stem is used,
                with ``'tiered'`` / ``'narrow'`` adjusting the stem widths;
                otherwise a single 7x7 convolution is used.
            act_layer: Activation spec passed to ``get_norm_act_layer``.
            norm_layer: Normalisation spec passed to ``get_norm_act_layer``.
            aa_layer: Optional callable invoked as ``aa_layer(channels=..., stride=2)``.
                When given, the stem pool is a stride-1 max pool followed by it.
            drop_rate: Dropout probability of ``self.head_drop``.
            proj_drop_rate: Forwarded to every ``DenseBlock`` as ``drop_rate``.
            memory_efficient: Forwarded to every ``DenseBlock`` as ``grad_checkpointing``.
            aa_stem_only: When true, transitions are built with ``aa_layer=None``.
        """
        super().__init__()
        self.num_classes = num_classes
        norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)

        # Stem
        deep_stem = 'deep' in stem_type  # 3x3 deep stem
        num_init_features = growth_rate * 2
        if aa_layer is None:
            # Replace pooling with Identity layer (no-op): no extra downsampling.
            stem_pool = nn.Identity()
            stem_pool_stride = 1
        else:
            # Stride-1 max pool followed by the stride-2 anti-aliasing layer.
            stem_pool = nn.Sequential(*[
                nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
                aa_layer(channels=num_init_features, stride=2)])
            stem_pool_stride = 2
        if deep_stem:
            stem_chs_1 = stem_chs_2 = growth_rate
            if 'tiered' in stem_type:
                stem_chs_1 = 3 * (growth_rate // 4)
                stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (growth_rate // 4)
            self.features = nn.Sequential(OrderedDict([
                ('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False)),
                ('norm0', norm_layer(stem_chs_1)),
                ('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False)),
                ('norm1', norm_layer(stem_chs_2)),
                ('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False)),
                ('norm2', norm_layer(num_init_features)),
                ('pool0', stem_pool),  # Identity unless an aa_layer was supplied
            ]))
        else:
            self.features = nn.Sequential(OrderedDict([
                ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                ('norm0', norm_layer(num_init_features)),
                ('pool0', stem_pool),  # Identity unless an aa_layer was supplied
            ]))
        self.feature_info = [
            dict(num_chs=num_init_features, reduction=2, module=f'features.norm{2 if deep_stem else 0}')]
        # The stride-2 stem conv gives a reduction of 2; the stem pool only adds
        # another factor of 2 when it actually downsamples. Keeping the previous
        # hard-coded 4 would misreport every later feature_info reduction.
        current_stride = 2 * stem_pool_stride

        # DenseBlocks
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                norm_layer=norm_layer,
                drop_rate=proj_drop_rate,
                grad_checkpointing=memory_efficient,
            )
            module_name = f'denseblock{(i + 1)}'
            self.features.add_module(module_name, block)
            num_features = num_features + num_layers * growth_rate
            transition_aa_layer = None if aa_stem_only else aa_layer
            if i != len(block_config) - 1:
                # Every block except the last is followed by a transition that
                # halves the channel count; the tracked stride doubles with it.
                self.feature_info += [
                    dict(num_chs=num_features, reduction=current_stride, module='features.' + module_name)]
                current_stride *= 2
                trans = DenseTransition(
                    num_input_features=num_features,
                    num_output_features=num_features // 2,
                    norm_layer=norm_layer,
                    aa_layer=transition_aa_layer,
                )
                self.features.add_module(f'transition{i + 1}', trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', norm_layer(num_features))

        self.feature_info += [dict(num_chs=num_features, reduction=current_stride, module='features.norm5')]
        self.num_features = self.head_hidden_size = num_features

        # Global Pooling & Classifier removed
        self.global_pool = nn.Identity()  # No pooling
        self.head_drop = nn.Dropout(drop_rate)
        self.classifier = nn.Identity()  # No classifier for feature extraction

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return the feature map for ``x``.

        ``x`` is passed through ``self.features``, then through the identity
        pool, the head dropout and the identity classifier, so the spatial
        dimensions produced by ``self.features`` are preserved.
        """
        # Forward pass without pooling
        x = self.features(x)
        x = self.global_pool(x)  # Identity, no global pooling
        x = self.head_drop(x)
        x = self.classifier(x)  # Identity, no classifier
        return x
