In [1]:
import torchvision

  from .autonotebook import tqdm as notebook_tqdm


In [37]:
# Reference model: torchvision's ResNet-50 built with the "IMAGENET1K_V2" weights.
model = torchvision.models.resnet50(weights="IMAGENET1K_V2")

In [38]:
# Show the module tree of the reference model.
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [39]:
# Inspect the reference model's state dict (the full repr is very large).
model.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[-7.4457e-03, -3.1783e-03,  3.7353e-02,  ...,  4.7936e-02,
                         -2.0325e-02,  8.8140e-03],
                        [-5.7435e-02,  4.4709e-02,  7.7509e-02,  ...,  8.8442e-02,
                          2.9346e-02, -5.8331e-02],
                        [ 6.8356e-02, -2.7044e-01,  4.0348e-01,  ..., -1.6491e-01,
                          2.1868e-01, -7.2909e-02],
                        ...,
                        [-1.0874e-01,  3.8148e-01, -4.5487e-01,  ...,  6.8366e-01,
                         -5.7855e-01,  2.2461e-01],
                        [ 2.5698e-02, -1.7703e-01,  6.4375e-01,  ...,  5.2644e-01,
                         -4.9317e-02, -6.8082e-02],
                        [ 4.5281e-02, -1.3072e-01,  1.7864e-02,  ..., -3.5753e-01,
                          1.8976e-01, -2.2302e-02]],
              
                       [[ 8.9197e-03,  4.8768e-03, -1.5356e-02,  ...,  8.6949e-02,
                         -6.5541

In [40]:
# Number of entries in the reference model's state dict.
len(model.state_dict())

320

### Custom model (project `Worker` ResNet)

In [5]:
import os
os.chdir("/home/zhuoyan/vision/branch_embedding/")
import argparse

import yaml
import torch
import torch.nn as nn
# from torch.utils.tensorboard import SummaryWriter

from libs.core import load_config
from libs.datasets import make_dataset, make_data_loader
from libs.model import Worker
from libs.utils import *

import argparse

def parse_args(input_args):
    """Parse job options from an explicit list of command-line style tokens.

    Args:
        input_args: sequence of argument strings, e.g.
            ``['-c', 'cfg.yaml', '-n', 'job', '-pf', '2']``.

    Returns:
        argparse.Namespace with ``config``, ``name``, ``gpu`` and
        ``print_freq``. ``print_freq`` is returned already multiplied by 100,
        because the flag is given in units of 100 iterations.
    """
    cli = argparse.ArgumentParser()

    # (flags, keyword options) for every supported argument
    option_specs = (
        (('-c', '--config'), dict(type=str, help='config file path')),
        (('-n', '--name'), dict(type=str, help='job name')),
        (('-g', '--gpu'), dict(type=str, default='0', help='GPU IDs')),
        (('-pf', '--print_freq'),
         dict(type=int, default=1, help='print frequency (x100 itrs)')),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    parsed = cli.parse_args(input_args)
    # convert "hundreds of iterations" into an iteration count
    parsed.print_freq = parsed.print_freq * 100
    return parsed

# Example: emulate a command line from inside the notebook.
# '-pf 2' becomes print_freq == 200 once parse_args applies its x100 scaling.
# NOTE(review): the job name mentions resnet18/cifar10 while the config file is
# resnet50_imagenet -- confirm the name is intentional.
input_args = [
    '-c', '/home/zhuoyan/vision/branch_embedding/configs/resnet50_imagenet.yaml',
    '-n', 'train_resnet18_cifar10',
    '-pf', '2',
]
args = parse_args(input_args)

print(args)

# Set up the checkpoint folder: log/<job name>.
os.makedirs('log', exist_ok=True)
ckpt_path = os.path.join('log', args.name)
# ensure_path / check_file are project helpers (presumably pulled in by the
# `from libs.utils import *` star import -- confirm).
ensure_path(ckpt_path)

# Validate and load the experiment configuration.
check_file(args.config)
cfg = load_config(args.config)
print('config loaded from command line')
print("begin")

# Build the project model wrapper from the 'model' section of the config.
worker = Worker(cfg['model'])

# Keep a copy of the config in the checkpoint folder. The context manager
# closes the file handle, which the previous bare open() never did.
with open(os.path.join(ckpt_path, 'config.yaml'), 'w') as config_file:
    yaml.dump(cfg, config_file)

Namespace(config='/home/zhuoyan/vision/branch_embedding/configs/resnet50_imagenet.yaml', name='train_resnet18_cifar10', gpu='0', print_freq=200)
config loaded from command line
begin


In [6]:
# The ResNet instance held by the project Worker; compared against `model` below.
myresnet = worker.resnet

In [7]:
# Inspect the custom ResNet's state dict (the full repr is very large).
myresnet.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[ 8.3249e-03,  2.8562e-02,  1.9534e-02,  ..., -4.8804e-02,
                         -9.6406e-03,  1.0822e-02],
                        [-3.6442e-02,  7.3035e-02, -6.1080e-04,  ..., -1.0023e-01,
                         -6.7153e-02,  3.0907e-02],
                        [-6.2699e-03,  1.1590e-01,  9.7606e-02,  ..., -1.9845e-02,
                         -1.8221e-03,  9.7530e-02],
                        ...,
                        [ 1.6269e-02,  2.7184e-01,  5.4999e-01,  ...,  5.5690e-01,
                          6.5946e-01,  5.1294e-01],
                        [-1.4853e-02,  1.1854e-01,  2.1143e-01,  ...,  3.3957e-01,
                          3.8925e-01,  3.1459e-01],
                        [-4.7338e-02,  1.3811e-02, -4.5144e-02,  ..., -9.2893e-02,
                         -1.6902e-02,  1.0397e-01]],
              
                       [[ 3.8145e-02,  1.0174e-01, -1.0502e-02,  ..., -9.1463e-02,
                         -1.8363

In [8]:
# Count how many state-dict keys line up, position by position, between the
# torchvision model and the custom ResNet.
# Only the key names are compared -- tensor values are NOT checked here.
# zip() stops at the shorter dict, so a full match also requires both dicts
# to have the same length.
com = [
    key == mykey
    for key, mykey in zip(model.state_dict().keys(), myresnet.state_dict().keys())
]
sum(com)

320

In [9]:
# Entry count of the torchvision model's state dict.
len(model.state_dict())

320

In [10]:
# Entry count of the custom ResNet's state dict.
len(myresnet.state_dict())

320

In [11]:
from torchvision.models import get_weight

In [41]:
# Resolve the weights enum member from its fully qualified string name.
weight = get_weight("ResNet50_Weights.IMAGENET1K_V2")
weight

ResNet50_Weights.IMAGENET1K_V2

In [42]:
# Confirm which enum class the resolved weights belong to.
type(weight)

<enum 'ResNet50_Weights'>

In [43]:
# Get the checkpoint's state dict straight from the weights enum, without
# building a model (presumably downloads or uses a cached file -- confirm).
weight_state_dict = weight.get_state_dict()
weight_state_dict

OrderedDict([('conv1.weight',
              tensor([[[[-7.4457e-03, -3.1783e-03,  3.7353e-02,  ...,  4.7936e-02,
                         -2.0325e-02,  8.8140e-03],
                        [-5.7435e-02,  4.4709e-02,  7.7509e-02,  ...,  8.8442e-02,
                          2.9346e-02, -5.8331e-02],
                        [ 6.8356e-02, -2.7044e-01,  4.0348e-01,  ..., -1.6491e-01,
                          2.1868e-01, -7.2909e-02],
                        ...,
                        [-1.0874e-01,  3.8148e-01, -4.5487e-01,  ...,  6.8366e-01,
                         -5.7855e-01,  2.2461e-01],
                        [ 2.5698e-02, -1.7703e-01,  6.4375e-01,  ...,  5.2644e-01,
                         -4.9317e-02, -6.8082e-02],
                        [ 4.5281e-02, -1.3072e-01,  1.7864e-02,  ..., -3.5753e-01,
                          1.8976e-01, -2.2302e-02]],
              
                       [[ 8.9197e-03,  4.8768e-03, -1.5356e-02,  ...,  8.6949e-02,
                         -6.5541

In [44]:
# Entry count of the checkpoint's state dict.
len(weight_state_dict)

320

In [45]:
# Count how many state-dict keys line up, position by position, between the
# instantiated torchvision model and the raw checkpoint state dict.
# Only the key names are compared -- tensor values are NOT checked here.
# zip() stops at the shorter dict, so a full match also requires both dicts
# to have the same length.
com = [
    key == mykey
    for key, mykey in zip(model.state_dict().keys(), weight_state_dict.keys())
]
sum(com)

320