In [1]:
import torch
from torch import nn
import os

try:
    from models.network import AutoEncoder
except ModuleNotFoundError:
    # Append base path with all needed code
    import pathlib
    import sys
    sys.path.append('../')
    # Try again
    from models.network import AutoEncoder

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    
def process_state_dict(network_state_dict, type = 0):

    if torch.cuda.device_count() >= 2 and type == 0:
        for key, item in list(network_state_dict.items()):
            if key[:7] != 'module.':
                new_key = 'module.' + key
                network_state_dict[new_key] = item
                del network_state_dict[key]
    else:
        for key, item in list(network_state_dict.items()):
            if key[:7] == 'module.':
                new_key = key[7:]
                network_state_dict[new_key] = item
                del network_state_dict[key]

    return network_state_dict

In [2]:
torch.__version__

'1.12.0+cu102'

In [3]:
import importlib

## import config here
spec = importlib.util.spec_from_file_location('*', '../configs/config_discrete_class.py')
config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config)

In [4]:
network = AutoEncoder(config=config)
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    print(f"Use {torch.cuda.device_count()} GPUS!")
    network = nn.DataParallel(network)
network = network.to(device)

Use 4 GPUS!


In [12]:
load_path = '/path/to/res/train_neural_template/' +\
    '2023-07-01_11-31-15_IM-Net-Training-experiment-home-AutoEncoder-3DCNN-Flow/model_epoch_2_325.pth'

network_state_dict = torch.load(load_path)

network_state_dict = process_state_dict(network_state_dict)
network_state_dict.keys()

odict_keys(['module.encoder.conv_1.weight', 'module.encoder.conv_1.bias', 'module.encoder.conv_2.weight', 'module.encoder.conv_2.bias', 'module.encoder.conv_3.weight', 'module.encoder.conv_3.bias', 'module.encoder.conv_4.weight', 'module.encoder.conv_4.bias', 'module.encoder.conv_5.weight', 'module.encoder.conv_5.bias', 'module.decoder.mean_parms', 'module.decoder.sqrt_end_time', 'module.decoder.ode_layers.0.ode_net.context_linear.weight', 'module.decoder.ode_layers.0.ode_net.context_linear.bias', 'module.decoder.ode_layers.0.ode_net.coordinate_linear.weight', 'module.decoder.ode_layers.0.ode_net.coordinate_linear.bias', 'module.decoder.ode_layers.0.ode_net.last_linear.weight', 'module.decoder.ode_layers.0.ode_net.last_linear.bias', 'module.decoder.ode_layers.0.ode_net.layers.0.weight', 'module.decoder.ode_layers.0.ode_net.layers.0.bias', 'module.decoder.ode_layers.0.ode_net.layers.1.weight', 'module.decoder.ode_layers.0.ode_net.layers.1.bias', 'module.decoder.linear_layer.weight', 'mo

In [9]:
network.load_state_dict(network_state_dict)

RuntimeError: Error(s) in loading state_dict for DataParallel:
	Missing key(s) in state_dict: "module.encoder.conv_1.weight", "module.encoder.conv_1.bias", "module.encoder.conv_2.weight", "module.encoder.conv_2.bias", "module.encoder.conv_3.weight", "module.encoder.conv_3.bias", "module.encoder.conv_4.weight", "module.encoder.conv_4.bias", "module.encoder.conv_5.weight", "module.encoder.conv_5.bias", "module.decoder.mean_parms", "module.decoder.sqrt_end_time", "module.decoder.ode_layers.0.ode_net.context_linear.weight", "module.decoder.ode_layers.0.ode_net.context_linear.bias", "module.decoder.ode_layers.0.ode_net.coordinate_linear.weight", "module.decoder.ode_layers.0.ode_net.coordinate_linear.bias", "module.decoder.ode_layers.0.ode_net.last_linear.weight", "module.decoder.ode_layers.0.ode_net.last_linear.bias", "module.decoder.ode_layers.0.ode_net.layers.0.weight", "module.decoder.ode_layers.0.ode_net.layers.0.bias", "module.decoder.ode_layers.0.ode_net.layers.1.weight", "module.decoder.ode_layers.0.ode_net.layers.1.bias", "module.decoder.linear_layer.weight", "module.decoder.linear_layer.bias", "module.decoder.bsp_field.concave_layer_weights", "module.decoder.bsp_field.convex_layer_weights", "module.decoder.bsp_field.plane_encoder.linear_1.weight", "module.decoder.bsp_field.plane_encoder.linear_1.bias", "module.decoder.bsp_field.plane_encoder.linear_2.weight", "module.decoder.bsp_field.plane_encoder.linear_2.bias", "module.decoder.bsp_field.plane_encoder.linear_3.weight", "module.decoder.bsp_field.plane_encoder.linear_3.bias", "module.decoder.bsp_field.plane_encoder.linear_4.weight", "module.decoder.bsp_field.plane_encoder.linear_4.bias", "module.decoder.recognition_decoder.linear_1.weight", "module.decoder.recognition_decoder.linear_1.bias", "module.decoder.recognition_decoder.linear_2.weight", "module.decoder.recognition_decoder.linear_2.bias", "module.decoder.recognition_decoder.linear_3.weight", "module.decoder.recognition_decoder.linear_3.bias", "module.decoder.recognition_decoder.linear_4.weight", "module.decoder.recognition_decoder.linear_4.bias". 
	Unexpected key(s) in state_dict: "module.state", "module.param_groups". 

In [12]:
torch.save(network.state_dict(), './test_save.pth')

In [20]:
network_state_dict['module.param_groups']

[{'lr': 5e-05,
  'betas': (0.5, 0.999),
  'eps': 1e-08,
  'weight_decay': 0,
  'amsgrad': False,
  'maximize': False,
  'foreach': None,
  'capturable': False,
  'differentiable': False,
  'fused': False,
  'params': [0,
   1,
   2,
   3,
   4,
   5,
   6,
   7,
   8,
   9,
   10,
   11,
   12,
   13,
   14,
   15,
   16,
   17,
   18,
   19,
   20,
   21,
   22,
   23,
   24,
   25,
   26,
   27,
   28,
   29,
   30,
   31,
   32,
   33,
   34,
   35,
   36,
   37,
   38,
   39,
   40,
   41]}]

In [6]:
import numpy as np

arr = np.random.randn(2, 24, 120, 120, 3)
arr1 = np.random.randn(2, 24, 120, 120, 3)

In [7]:
arr[:, :, :, :, None].shape

(2, 24, 120, 120, 1, 3)

In [1]:
import torch
import torch.nn.functional as F

import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [200]:
bn = torch.nn.BatchNorm2d(4).cpu()

In [201]:
bn.weight.requires_grad = False

In [202]:
t_z = torch.from_numpy(np.ones( (1, 4, 4, 4) ) * 100).float()
t = torch.rand(1, 4, 4, 4, requires_grad=True).float()

In [203]:
t

tensor([[[[0.1881, 0.7187, 0.2104, 0.2221],
          [0.8252, 0.0773, 0.6189, 0.8774],
          [0.3207, 0.4527, 0.3888, 0.3276],
          [0.1700, 0.4823, 0.5137, 0.7986]],

         [[0.6632, 0.6943, 0.5005, 0.3868],
          [0.7616, 0.6865, 0.5945, 0.0834],
          [0.9453, 0.3598, 0.7002, 0.9613],
          [0.4366, 0.6929, 0.3710, 0.4919]],

         [[0.4835, 0.3189, 0.7357, 0.3430],
          [0.4408, 0.6465, 0.2445, 0.3672],
          [0.6501, 0.6324, 0.9951, 0.3063],
          [0.4325, 0.2988, 0.1153, 0.2191]],

         [[0.8946, 0.0480, 0.9874, 0.5353],
          [0.3821, 0.0091, 0.1255, 0.7976],
          [0.9778, 0.4896, 0.0179, 0.9502],
          [0.1056, 0.5423, 0.9121, 0.3488]]]], requires_grad=True)

In [204]:
optimizer = torch.optim.Adam(params=[t, bn.weight, bn.bias], lr=0.3)

In [221]:
optimizer.zero_grad()
loss = ((t_z - bn(t)) ** 2).mean()
loss.backward()
optimizer.step()

In [222]:
bn.weight, bn.bias

(Parameter containing:
 tensor([1., 1., 1., 1.]),
 Parameter containing:
 tensor([2.6978, 2.6978, 2.6978, 2.6978], requires_grad=True))