In [1]:
import torchvision

  from .autonotebook import tqdm as notebook_tqdm


In [37]:
model = torchvision.models.resnet50(weights="IMAGENET1K_V2")

In [38]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [39]:
model.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[-7.4457e-03, -3.1783e-03,  3.7353e-02,  ...,  4.7936e-02,
                         -2.0325e-02,  8.8140e-03],
                        [-5.7435e-02,  4.4709e-02,  7.7509e-02,  ...,  8.8442e-02,
                          2.9346e-02, -5.8331e-02],
                        [ 6.8356e-02, -2.7044e-01,  4.0348e-01,  ..., -1.6491e-01,
                          2.1868e-01, -7.2909e-02],
                        ...,
                        [-1.0874e-01,  3.8148e-01, -4.5487e-01,  ...,  6.8366e-01,
                         -5.7855e-01,  2.2461e-01],
                        [ 2.5698e-02, -1.7703e-01,  6.4375e-01,  ...,  5.2644e-01,
                         -4.9317e-02, -6.8082e-02],
                        [ 4.5281e-02, -1.3072e-01,  1.7864e-02,  ..., -3.5753e-01,
                          1.8976e-01, -2.2302e-02]],
              
                       [[ 8.9197e-03,  4.8768e-03, -1.5356e-02,  ...,  8.6949e-02,
                         -6.5541

In [40]:
len(model.state_dict().keys())

320

### custom model

In [26]:
import os
os.chdir("/Users/zyxu/Documents/py/vision/adaptive_inference/")
import argparse

import yaml
import torch
import torch.nn as nn
# from torch.utils.tensorboard import SummaryWriter

from libs.core import load_config
from libs.datasets import make_dataset, make_data_loader
from libs.model import Worker
from libs.utils import *

import argparse

def parse_args(input_args):
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, help='config file path')
    parser.add_argument('-n', '--name', type=str, help='job name')
    parser.add_argument('-g', '--gpu', type=str, default='0', help='GPU IDs')
    parser.add_argument('-pf', '--print_freq', type=int, default=1, help='print frequency (x100 itrs)')
    
    args = parser.parse_args(input_args)
    args.print_freq *= 100
    
    return args

# Example
input_args = ['-c', '/Users/zyxu/Documents/py/vision/adaptive_inference/configs/resnet50_imagenet.yaml', 
              '-n', 'train_resnet18_cifar10', 
              '-pf', '2']
args = parse_args(input_args)

print(args)

# set up checkpoint folder
os.makedirs('log', exist_ok=True)
ckpt_path = os.path.join('log', args.name)
ensure_path(ckpt_path)

check_file(args.config)
cfg = load_config(args.config)
print('config loaded from command line')
print("begin")

worker = Worker(cfg['model'])
yaml.dump(cfg, open(os.path.join(ckpt_path, 'config.yaml'), 'w'))

Namespace(config='/Users/zyxu/Documents/py/vision/adaptive_inference/configs/resnet50_imagenet.yaml', name='train_resnet18_cifar10', gpu='0', print_freq=200)
path exists! path:  log/train_resnet18_cifar10
config loaded from command line
begin


In [28]:
cfg['model']

{'resnet': {'arch': 'resnet50', 'dataset': 'imagenet'},
 'branch_enc': {'embd_dim': 256,
  'out_dim': 128,
  'n_heads': 4,
  'n_layers': 5,
  'attn_pdrop': 0.1,
  'proj_pdrop': 0.1,
  'path_pdrop': 0.1,
  'eos': False,
  'embd_type': 0,
  'pe_type': 0,
  'seq_len': 15},
 'content_enc': {'out_dim': 128,
  'arch': 'resnet10_imagenet',
  'pretrained': False},
 'branch_vae': {'hid_dim': 64, 'n_layers': 3, 'latent_dim': 4, 'in_dim': 15}}

In [6]:
myresnet = worker.resnet

In [4]:
myresnet.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[ 8.3249e-03,  2.8562e-02,  1.9534e-02,  ..., -4.8804e-02,
                         -9.6406e-03,  1.0822e-02],
                        [-3.6442e-02,  7.3035e-02, -6.1080e-04,  ..., -1.0023e-01,
                         -6.7153e-02,  3.0907e-02],
                        [-6.2699e-03,  1.1590e-01,  9.7606e-02,  ..., -1.9845e-02,
                         -1.8221e-03,  9.7530e-02],
                        ...,
                        [ 1.6269e-02,  2.7184e-01,  5.4999e-01,  ...,  5.5690e-01,
                          6.5946e-01,  5.1294e-01],
                        [-1.4853e-02,  1.1854e-01,  2.1143e-01,  ...,  3.3957e-01,
                          3.8925e-01,  3.1459e-01],
                        [-4.7338e-02,  1.3811e-02, -4.5144e-02,  ..., -9.2893e-02,
                         -1.6902e-02,  1.0397e-01]],
              
                       [[ 3.8145e-02,  1.0174e-01, -1.0502e-02,  ..., -9.1463e-02,
                         -1.8363

In [5]:
com = []
for (key,val),(mykey,myval) in zip(model.state_dict().items(), myresnet.state_dict().items()):
    com.append(key == mykey)
sum(com)

NameError: name 'model' is not defined

In [9]:
len(model.state_dict().items())

320

In [10]:
len(myresnet.state_dict().items())

320

In [1]:
from torchvision.models import get_weight

In [2]:
weight = get_weight("ResNet50_Weights.IMAGENET1K_V2")
weight

ResNet50_Weights.IMAGENET1K_V2

In [3]:
type(weight)

<enum 'ResNet50_Weights'>

In [4]:
weight_state_dict = weight.get_state_dict()
weight_state_dict

TypeError: get_state_dict() missing 1 required positional argument: 'progress'

In [44]:
len(weight_state_dict.keys())

320

In [45]:
com = []
for (key,val),(mykey,myval) in zip(model.state_dict().items(), weight_state_dict.items()):
    com.append(key == mykey)
sum(com)

320

In [20]:
import torch
x = torch.rand((2,3,4))
x

tensor([[[0.0782, 0.3391, 0.9841, 0.0216],
         [0.6169, 0.1615, 0.5126, 0.7672],
         [0.2970, 0.1550, 0.2595, 0.4295]],

        [[0.6133, 0.4709, 0.1723, 0.9655],
         [0.8647, 0.0345, 0.1493, 0.8706],
         [0.8772, 0.9480, 0.7761, 0.7376]]])

In [21]:
keep_prob = 1 - 0.5
shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
x.new_empty(shape).bernoulli_(keep_prob)

tensor([[[1.]],

        [[1.]]])

In [8]:
shape

(2, 1)

### timm

In [3]:
import torch
from pprint import pprint
import matplotlib.pyplot as plt
import numpy as np
import os
os.chdir("/Users/zyxu/Documents/py/vision/adaptive_inference")
import sys
sys.path.insert(0, "/Users/zyxu/Documents/py/vision/adaptive_inference")
import timm
model_names = timm.list_models('*resnet*50*')

pprint(model_names)

['cspresnet50',
 'cspresnet50d',
 'cspresnet50w',
 'ecaresnet50d',
 'ecaresnet50d_pruned',
 'ecaresnet50t',
 'gcresnet50t',
 'lambda_resnet50ts',
 'legacy_seresnet50',
 'nf_ecaresnet50',
 'nf_resnet50',
 'nf_seresnet50',
 'resnet50',
 'resnet50_gn',
 'resnet50c',
 'resnet50d',
 'resnet50s',
 'resnet50t',
 'resnetaa50',
 'resnetaa50d',
 'resnetblur50',
 'resnetblur50d',
 'resnetrs50',
 'resnetrs350',
 'resnetv2_50',
 'resnetv2_50d',
 'resnetv2_50d_evos',
 'resnetv2_50d_frn',
 'resnetv2_50d_gn',
 'resnetv2_50t',
 'resnetv2_50x1_bit',
 'resnetv2_50x3_bit',
 'seresnet50',
 'seresnet50t',
 'seresnetaa50d',
 'skresnet50',
 'skresnet50d',
 'vit_base_resnet50d_224',
 'vit_small_resnet50d_s16_224',
 'wide_resnet50_2']


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
resnet18 = timm.create_model("resnet18", pretrained=True)

In [8]:
resnet18

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act1): ReLU(inplace=True)
      (aa): Identity()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, m

In [4]:
resnet18.pretrained_cfg

{'url': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
 'num_classes': 1000,
 'input_size': (3, 224, 224),
 'pool_size': (7, 7),
 'crop_pct': 0.875,
 'interpolation': 'bilinear',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'first_conv': 'conv1',
 'classifier': 'fc',
 'architecture': 'resnet18'}

In [18]:
resnet50 = timm.create_model("resnet50.a1_in1k", pretrained=True)

zhuoyan variant: resnet50
------- kwargs: dict_keys(['block', 'layers', 'pretrained_cfg', 'pretrained_cfg_overlay'])
Loading pretrained weights from Hugging Face hub (timm/resnet50.a1_in1k)
zhuoyan: len(state_dict) 320


In [9]:
resnet50.pretrained_cfg

{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth',
 'hf_hub_id': 'timm/resnet50.a1_in1k',
 'architecture': 'resnet50',
 'tag': 'a1_in1k',
 'custom_load': False,
 'input_size': (3, 224, 224),
 'test_input_size': (3, 288, 288),
 'fixed_input_size': False,
 'interpolation': 'bicubic',
 'crop_pct': 0.95,
 'test_crop_pct': 1.0,
 'crop_mode': 'center',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'num_classes': 1000,
 'pool_size': (7, 7),
 'first_conv': 'conv1',
 'classifier': 'fc',
 'origin_url': 'https://github.com/huggingface/pytorch-image-models',
 'paper_ids': 'arXiv:2110.00476'}

In [11]:
len(resnet50.state_dict().items())

320

In [12]:
len(myresnet.state_dict().items())

320

In [13]:
com = []
for (key,val),(mykey,myval) in zip(resnet50.state_dict().items(), myresnet.state_dict().items()):
    com.append(key == mykey)
sum(com)

320

In [16]:
checkpoint = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet50t_ra2-f7ac63c4.pth'
state_dict = torch.hub.load_state_dict_from_url(checkpoint, progress=True)
len(state_dict)

348

In [17]:
state_dict

OrderedDict([('conv1.0.weight',
              tensor([[[[ 1.0862e-01,  1.2507e-01,  8.3548e-02],
                        [-1.7417e-01, -4.7524e-01,  3.1098e-02],
                        [ 2.3888e-02,  1.5853e+00, -1.3207e+00]],
              
                       [[-1.9070e-02,  5.8473e-02,  1.2558e-01],
                        [-2.8809e-01, -7.7538e-01,  3.0600e-01],
                        [-1.0539e-01,  2.7164e+00, -2.0706e+00]],
              
                       [[-1.7092e-02, -1.6253e-02,  1.3529e-01],
                        [-3.0489e-01, -3.9220e-01,  1.5191e-01],
                        [-1.8679e-02,  1.2251e+00, -6.8082e-01]]],
              
              
                      [[[ 2.5032e-01, -2.9530e-01,  8.9573e-02],
                        [-2.2691e-01, -2.2153e-01,  7.7606e-01],
                        [ 3.6369e-02,  7.8641e-01, -1.1467e+00]],
              
                       [[-2.4190e-01,  3.0816e-01, -9.8971e-02],
                        [ 3.0372e-01,  6.15

In [12]:
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
state_dict = load_state_dict_from_hf("timm/resnet50.a1_in1k")
len(state_dict)

320

In [10]:
com = []
for (key,val),(mykey,myval) in zip(state_dict.items(), myresnet.state_dict().items()):
    com.append(key == mykey)
sum(com)

48

In [14]:
from collections import OrderedDict

# Assuming 'state_dict' is a dictionary containing the model state
# and 'myresnet.state_dict()' is another dictionary with the desired key order.

# Create an ordered dictionary with keys in the same order as myresnet.state_dict()
ordered_state_dict = OrderedDict((key, state_dict[key]) for key in myresnet.state_dict().keys() if key in state_dict)

# Now you can compare the ordered_state_dict with myresnet.state_dict()
com = []
for (key, val), (mykey, myval) in zip(ordered_state_dict.items(), myresnet.state_dict().items()):
    com.append(key == mykey)
sum(com)


320

In [15]:
ordered_state_dict

OrderedDict([('conv1.weight',
              tensor([[[[ 8.3249e-03,  2.8562e-02,  1.9534e-02,  ..., -4.8804e-02,
                         -9.6406e-03,  1.0822e-02],
                        [-3.6442e-02,  7.3035e-02, -6.1080e-04,  ..., -1.0023e-01,
                         -6.7153e-02,  3.0907e-02],
                        [-6.2699e-03,  1.1590e-01,  9.7606e-02,  ..., -1.9845e-02,
                         -1.8221e-03,  9.7530e-02],
                        ...,
                        [ 1.6269e-02,  2.7184e-01,  5.4999e-01,  ...,  5.5690e-01,
                          6.5946e-01,  5.1294e-01],
                        [-1.4853e-02,  1.1854e-01,  2.1143e-01,  ...,  3.3957e-01,
                          3.8925e-01,  3.1459e-01],
                        [-4.7338e-02,  1.3811e-02, -4.5144e-02,  ..., -9.2893e-02,
                         -1.6902e-02,  1.0397e-01]],
              
                       [[ 3.8145e-02,  1.0174e-01, -1.0502e-02,  ..., -9.1463e-02,
                         -1.8363

In [19]:
resnet50.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[ 8.3249e-03,  2.8562e-02,  1.9534e-02,  ..., -4.8804e-02,
                         -9.6406e-03,  1.0822e-02],
                        [-3.6442e-02,  7.3035e-02, -6.1080e-04,  ..., -1.0023e-01,
                         -6.7153e-02,  3.0907e-02],
                        [-6.2699e-03,  1.1590e-01,  9.7606e-02,  ..., -1.9845e-02,
                         -1.8221e-03,  9.7530e-02],
                        ...,
                        [ 1.6269e-02,  2.7184e-01,  5.4999e-01,  ...,  5.5690e-01,
                          6.5946e-01,  5.1294e-01],
                        [-1.4853e-02,  1.1854e-01,  2.1143e-01,  ...,  3.3957e-01,
                          3.8925e-01,  3.1459e-01],
                        [-4.7338e-02,  1.3811e-02, -4.5144e-02,  ..., -9.2893e-02,
                         -1.6902e-02,  1.0397e-01]],
              
                       [[ 3.8145e-02,  1.0174e-01, -1.0502e-02,  ..., -9.1463e-02,
                         -1.8363

In [24]:
com = []
for (key, val), (mykey, myval) in zip(ordered_state_dict.items(), resnet50.state_dict().items()):
    # Check if keys and values are the same
    # The values are typically tensors, so we use torch.equal to compare them
    com.append((key == mykey) and torch.equal(val, myval))

# Sum the results to get the number of matches
total_matches = sum(com)

# If total_matches is equal to the length of the state_dict (or myresnet.state_dict()),
# it means all keys and values match.
total_matches

320

In [25]:
myresnet.load_state_dict(ordered_state_dict)

<All keys matched successfully>