In [1]:
import torch
import os
import torchvision.models.resnet as resnet
from typing import Type, Any, Callable, Union, List, Optional
import torch.nn as nn
from torch import Tensor

In [2]:
# Prefer the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint to convert, resolved against the current working directory.
pretrained_weights = os.path.join(os.getcwd(), "inputs", "resnet.pth")
# Top-level checkpoint entry whose weights are extracted when it is present.
checkpoint_key = "teacher"
# NOTE(review): not referenced by any cell visible in this notebook.
patch_size = 8

In [3]:
class ResnetBackbone(resnet.ResNet):
    """torchvision ResNet whose forward pass stops after ``layer4``.

    The parent constructor is called unchanged, so every submodule of the
    regular classifier (including ``fc``) is still created and still appears
    in ``state_dict()``; only the forward computation is shortened so that the
    final feature map is returned instead of class logits.

    Attributes:
        num_channels: Number of channels of the tensor returned by the
            forward pass (``512 * block.expansion``).
    """

    def __init__(
            self,
            block: Type[Union[resnet.BasicBlock, resnet.Bottleneck]],
            layers: List[int],
            num_classes: int = 1000,
            zero_init_residual: bool = False,
            groups: int = 1,
            width_per_group: int = 64,
            replace_stride_with_dilation: Optional[List[bool]] = None,
            norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """Build the underlying ResNet; all arguments are forwarded as-is.

        Args:
            block: Residual block class (``BasicBlock`` or ``Bottleneck``).
            layers: Number of blocks in each of the four stages.
            num_classes: Size of the (unused in forward) ``fc`` layer.
            zero_init_residual: Forwarded to ``resnet.ResNet``.
            groups: Forwarded to ``resnet.ResNet``.
            width_per_group: Forwarded to ``resnet.ResNet``.
            replace_stride_with_dilation: Forwarded to ``resnet.ResNet``.
            norm_layer: Forwarded to ``resnet.ResNet``.
        """
        super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group,
                         replace_stride_with_dilation, norm_layer)
        # layer4 is built with 512 planes and each block widens its output by
        # `block.expansion`, so this is 2048 for Bottleneck (the value that was
        # previously hard-coded) and 512 for BasicBlock.
        self.num_channels = 512 * block.expansion

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run the stem and the four residual stages, skipping pooling and ``fc``.

        This overrides the hook that ``resnet.ResNet.forward`` delegates to, so
        calling the module returns the ``layer4`` output directly.

        Args:
            x: Input batch fed to ``conv1``.

        Returns:
            The output of ``layer4``.
        """
        # Stem: convolution, normalisation, activation, max-pooling.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Residual stages; the parent's avgpool / flatten / fc steps are
        # intentionally left out.
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x


In [4]:
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResnetBackbone:
    """Construct a ResNet-50 shaped ``ResnetBackbone``.

    Args:
        pretrained: Passed through to ``_resnet``.
        progress: Passed through to ``_resnet``.
        **kwargs: Extra keyword arguments passed through to ``_resnet``.

    Returns:
        The backbone built from ``Bottleneck`` blocks with stage depths 3/4/6/3.
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet(
        "resnet50",
        resnet.Bottleneck,
        stage_depths,
        pretrained,
        progress,
        **kwargs,
    )

In [5]:
def _resnet(
    arch: str,
    block: Type[Union[resnet.BasicBlock, resnet.Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any,
) -> ResnetBackbone:
    """Instantiate a ``ResnetBackbone`` without loading any weights.

    The signature mirrors torchvision's private ``_resnet`` helper, but this
    version never downloads or loads a checkpoint.

    Args:
        arch: Architecture name; only used in the warning message.
        block: Residual block class handed to ``ResnetBackbone``.
        layers: Number of blocks per stage handed to ``ResnetBackbone``.
        pretrained: Not supported. A ``UserWarning`` is emitted when truthy so
            the caller does not mistake the returned model for a pretrained one.
        progress: Accepted for signature compatibility; has no effect.
        **kwargs: Forwarded to the ``ResnetBackbone`` constructor.

    Returns:
        A newly constructed ``ResnetBackbone`` with its initial weights.
    """
    if pretrained:
        # Local import keeps this cell self-contained; the notebook's import
        # cell does not bring in `warnings`.
        import warnings

        warnings.warn(
            f"pretrained=True is ignored for {arch!r}: no weights are loaded here, "
            "load a checkpoint with load_state_dict instead.",
            UserWarning,
            stacklevel=2,
        )
    return ResnetBackbone(block, layers, **kwargs)

In [6]:
# Build the backbone and overwrite its initial parameters with the checkpoint
# stored at `pretrained_weights`.
backbone = resnet50()
backbone.to(DEVICE)

if not os.path.isfile(pretrained_weights):
    # Fail loudly: carrying on would leave `backbone` with its initial weights,
    # and any later cell that saves it would silently export an untrained model.
    raise FileNotFoundError(
        "Pretrained weights not found at {}".format(pretrained_weights)
    )

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only point `pretrained_weights` at checkpoints from a trusted source.
state_dict = torch.load(pretrained_weights, map_location="cpu")
if checkpoint_key is not None and checkpoint_key in state_dict:
    print(f"Take key {checkpoint_key} in provided checkpoint dict")
    state_dict = state_dict[checkpoint_key]

# Strip wrapper prefixes so the keys line up with the bare ResNet parameter
# names. Only a *leading* prefix is removed; a plain str.replace would also
# rewrite matches in the middle of a key.
#   "module."   - presumably added by a (Distributed)DataParallel wrapper; confirm
#   "backbone." - prefix induced by the multicrop wrapper
cleaned_state_dict = {}
for key, value in state_dict.items():
    for prefix in ("module.", "backbone."):
        if key.startswith(prefix):
            key = key[len(prefix):]
    cleaned_state_dict[key] = value
state_dict = cleaned_state_dict

# strict=False tolerates keys that exist on only one side; the returned
# message lists them and is printed so mismatches can be reviewed.
msg = backbone.load_state_dict(state_dict, strict=False)
print(
    "Pretrained weights found at {} and loaded with msg: {}".format(
        pretrained_weights, msg
    )
)

Take key teacher in provided checkpoint dict
Pretrained weights found at C:\Users\YifengKou\OneDrive - nyu.edu\Desktop\dino\inputs\resnet.pth and loaded with msg: _IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])


In [7]:
# Persist the converted backbone in two forms under ./zoo.
zoo_dir = "zoo"
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing.
os.makedirs(zoo_dir, exist_ok=True)

# 1) Parameters and buffers only.
# NOTE(review): despite the "nofc" file name, the state dict is written
# unfiltered, so any `fc.*` entries the module owns are included - confirm
# that this is what downstream loaders expect.
torch.save(backbone.state_dict(), os.path.join(zoo_dir, "resnet-full-nofc.pth"))

# 2) The whole pickled module.
# NOTE(review): pickling stores a reference to the class rather than its code,
# so loading this file requires `ResnetBackbone` to be importable under the
# same module path; prefer the state-dict file where possible.
torch.save(backbone, os.path.join(zoo_dir, "resnet-full-nofc"))