In [4]:
# THIS NOTEBOOK IS FOR LOADING THE PRETRAINED EM 300micron-by-300micron model FROM MATLAB

# ALWAYS RUN THIS CELL
# from AntennaNetwork import AntennaCNN#imports for next cells 
import os

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from tqdm.auto import tqdm

from EM300CNN import EM300CNN

In [5]:
# RUN THIS CELL TO IMPORT MODEL PARAMS FROM MATLAB (not necessary if already run)
# Converts the MATLAB-trained parameters in em300_params.mat into a PyTorch
# state_dict saved at ./saved/EM300CNN, which the next cell loads.

MAT_PATH = 'em300_params.mat'
SAVE_PATH = "./saved/EM300CNN"
CONV_LAYERS = 12  # count of conv layers in model
FC_LAYERS = 5    # count of fc layers


def _read_tensor(h5file, ref):
    """Dereference the HDF5 object reference `ref` in `h5file` and return its data as a tensor."""
    return torch.tensor(np.array(h5file[ref]))


def _assign(target, source, reshape=False):
    """Copy `source` into the existing parameter/buffer `target` in place.

    Copying keeps `target`'s dtype, device and registration (parameter vs. buffer),
    unlike rebinding `.data` or assigning a fresh nn.Parameter.

    Args:
        target: the model tensor to overwrite.
        source: the tensor read from the .mat file.
        reshape: if False, shapes must match exactly; if True, only element counts
            must match and `source` is reshaped to `target.shape` (for MATLAB vectors
            that carry extra singleton axes).

    Raises:
        ValueError: if the shapes (or, with reshape=True, the element counts) differ.
    """
    if reshape:
        if source.numel() != target.numel():
            raise ValueError(f"expected {target.numel()} elements, got {source.numel()}")
        source = source.reshape(target.shape)
    elif source.shape != target.shape:
        raise ValueError(f"expected shape {tuple(target.shape)}, got {tuple(source.shape)}")
    target.copy_(source)


model = EM300CNN()

# Load matlab data as file; the context manager closes it once every tensor has been
# copied, and no_grad allows in-place copies into parameters that require grad.
with h5py.File(MAT_PATH, 'r') as f, torch.no_grad():
    # C is a matlab cell array storing the weights and biases of the conv and fc layers:
    # every pair of entries is (weights, biases), conv layers first, then fc layers
    weights_and_biases = f['C']
    # BN_Params holds one entry per batchnorm layer
    bn_params = f['BN_Params']

    # Conv layers: entries 0 .. 2*CONV_LAYERS-1 of C, matched in model.named_parameters() order
    idx = 0
    for name, param in model.named_parameters():
        if "conv" not in name:
            continue
        data = _read_tensor(f, weights_and_biases[idx, 0])
        if idx % 2 == 0:
            # h5py returns MATLAB arrays with their axes reversed; swapping the last two
            # axes yields PyTorch's (out, in, kH, kW) conv layout, assuming MATLAB stored
            # the filters as kH x kW x in x out
            _assign(param, torch.transpose(data, 3, 2))
        else:
            _assign(param, data, reshape=True)
        idx += 1
    # the fc loop below starts reading at 2*CONV_LAYERS, so the counts must agree
    assert idx == 2 * CONV_LAYERS, f"expected {2 * CONV_LAYERS} conv tensors, model has {idx}"

    # FC layers: the next 2*FC_LAYERS entries of C
    idx = 2 * CONV_LAYERS
    for name, param in model.named_parameters():
        if "fc_" not in name:
            continue
        data = _read_tensor(f, weights_and_biases[idx, 0])
        if idx % 2 == 0:
            # transpose the reversed MATLAB axes into nn.Linear's (out_features, in_features)
            _assign(param, data.T)
        else:
            _assign(param, data, reshape=True)
        idx += 1
    assert idx == 2 * (CONV_LAYERS + FC_LAYERS), (
        f"expected {2 * FC_LAYERS} fc tensors, model has {idx - 2 * CONV_LAYERS}")

    # Batchnorm layers: each BN_Params entry references four arrays, in the order
    # running_mean, running_var, weight/gamma, bias/beta
    bn_idx = 0
    for name, module in model.named_children():
        if 'batchnorm' not in name:
            continue
        entry = f[bn_params[bn_idx, 0]]
        # copying in place keeps running_mean/running_var registered as buffers;
        # assigning an nn.Parameter to them would re-register them as parameters
        targets = (module.running_mean, module.running_var, module.weight, module.bias)
        for col, target in enumerate(targets):
            _assign(target, _read_tensor(f, entry[0, col]), reshape=True)
        bn_idx += 1

os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True)
torch.save(model.state_dict(), SAVE_PATH)

In [6]:
# RUN THIS TO LOAD THE MODEL AND TEST IT
# Rebuild the architecture and load the state_dict written by the MATLAB-import cell.
model = EM300CNN()
load_result = model.load_state_dict(torch.load("./saved/EM300CNN"))
# Inference mode: batchnorm layers use the loaded running_mean/running_var
# instead of statistics of the current batch.
model.eval()
load_result


<All keys matched successfully>