# ResNet-18 with MNIST

## Dependencies

In [1]:
# Basic
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
# PyTorch
import torch
import torchvision
from torchvision import transforms
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

## Configs

In [3]:
# configs
config = {
    "data": {
        "path": "../data",
        "in_channels": 1,
        "num_classes": 10,
        "batch_size": 512
    },
    "model": {
        
    },
    "training":{
        "learning_rate": 1e-5,
        "optimizer": "sgd",
        "epochs": 5
    }
}

In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Data

In [5]:
# transform for MNIST 
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)) # specific for MNIST
])

In [6]:
train_dataset = torchvision.datasets.MNIST(
    root = config['data']['path'],
    train = True,
    transform = transform,
    download = False
)

test_dataset = torchvision.datasets.MNIST(
    root = config['data']['path'],
    train = False,
    transform = transform
)

In [7]:
train_loader = torch.utils.data.DataLoader(
    dataset = train_dataset,
    batch_size = config['data']['batch_size'],
    shuffle = True
)
test_loader = torch.utils.data.DataLoader(
    dataset = test_dataset,
    batch_size = config['data']['batch_size'],
    shuffle=False
)

## Model

In [8]:
# class BasicBlock(nn.Module):
#     expansion = 1
    
#     def __init__(self, in_channels, channels, stride=1):
#         super(BasicBlock, self).__init__()
#         self.expansion = 1
#         self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
#         self.bn1 = nn.BatchNorm2d(channels)
#         self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
#         self.bn2 = nn.BatchNorm2d(channels)
        
#         self.shortcut = nn.Sequential()
#         if stride != 1 or in_channels != self.expansion*channels:
#             # modify shortcut for correct channel output
#             self.shortcut = nn.Sequential(
#                 nn.Conv2d(in_channels, self.expansion*channels, kernel_size=1, stride=stride, bias=False),
#                 nn.BatchNorm2d(self.expansion * channels)
#             )
            
#     def forward(self, x):
#         out = nn.ReLU()(self.bn1(self.conv1(x)))
#         out = self.bn2(self.conv2(out))
#         out += self.shortcut(x)
#         out = nn.ReLU()(out)
#         return out
    
# class ResNet(nn.Module):
#     def __init__(self, block, num_blocks, in_channels, num_classes):
#         super(ResNet, self).__init__()
#         self.in_channels = in_channels
        
#         self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
#         self.in_channels = 64 # modify as it after first channel
#         self.bn1 = nn.BatchNorm2d(64)
#         self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
#         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
#         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
#         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
#         self.linear = nn.Linear(512*block.expansion, num_classes)
        
        
#     def _make_layer(self, block, channels, num_blocks, stride):
#         strides = [stride] + [1]*(num_blocks-1)
#         layers = []
        
#         for stride in strides:
#             layers.append(block(self.in_channels, channels, stride))
#             self.in_channels = channels * block.expansion
#         return nn.Sequential(*layers)
    
#     def forward(self, x):
#         out = nn.ReLU()(self.bn1(self.conv1(x)))
#         out = self.layer1(out)
#         out = self.layer2(out)
#         out = self.layer3(out)
#         out = self.layer4(out)
#         out = nn.AdaptiveAvgPool2d(1)(out)
#         out = out.view(out.size(0), -1)
#         out = self.linear(out)
#         return out

In [14]:
# Tried new architecture for ResNet-18
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes, grayscale):
        self.inplanes = 64
        if grayscale:
            in_dim = 1
        else:
            in_dim = 3
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, (2. / n)**.5)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # because MNIST is already 1x1 here:
        # disable avg pooling
        #x = self.avgpool(x)
        
        x = x.view(x.size(0), -1)
        logits = self.fc(x)
        probas = F.softmax(logits, dim=1)
        return logits, probas

def resnet18(num_classes):
    """Constructs a ResNet-18 model."""
    model = ResNet(block=BasicBlock, 
                   layers=[2, 2, 2, 2],
                   num_classes=config['data']['num_classes'],
                   grayscale=True)
    return model

In [26]:
# model = ResNet(
#     BasicBlock, 
#     num_blocks = [2, 2, 2, 2], 
#     in_channels = config['data']['in_channels'],
#     num_classes = config['data']['num_classes']
# ).to(device)
model = resnet18(10).to(device)

## Training

In [16]:
# helper functions for training
def calculate_acc(logits, labels):
    """
    Given logits and correct labels:
    Return number of corrections and accuracy
    """
    pred = logits.argmax(dim=1, keepdim=True)
    correct = pred.eq(labels.view_as(pred)).sum().detach().cpu().item()
    return correct, 100 * correct / len(inputs)

In [32]:
criterion = nn.CrossEntropyLoss()
if config['training']['optimizer'] == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01) 
elif config['training']['topmizer'] == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
optimizer

SGD (
Parameter Group 0
    dampening: 0
    differentiable: False
    foreach: None
    lr: 0.01
    maximize: False
    momentum: 0
    nesterov: False
    weight_decay: 0
)

In [18]:
# TRAINING

# Assitant List for showing live loss/acc
live_losses = []
live_accs = []

epoch_describer = tqdm(range(config['training']['epochs']), desc=f"Epoch {1}", ncols=100)

training_losses = []
training_accs = []
validation_losses = []
validation_accs = []
    
for epoch in epoch_describer:
    
    training_loss = 0.0
    training_corr = 0.0
    
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        outputs,_ = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        correct, accuracy = calculate_acc(outputs, labels)
        
        training_corr += correct
        training_loss += loss.detach().cpu().item()
        
        live_losses.append(loss.detach().cpu().item())
        live_accs.append(accuracy)
        if len(live_losses) > 100:
            live_losses.pop(0)
            live_accs.pop(0)
            
        # update 
        epoch_describer.\
        set_description(f"Epoch {epoch+1} [loss:{np.mean(live_losses):.2f}, acc:{np.mean(live_accs):.2f}]")
        
    # record training data per epoch
    training_losses.append(training_loss / len(train_loader))
    training_accs.append(training_corr / len(train_loader.dataset))
        
    # validate the result
    valid_loss = 0.0
    valid_corr = 0.0
    
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        with torch.no_grad():
            outputs,_ = model(inputs)
            loss = criterion(outputs, labels)
            correct, accuracy = calculate_acc(outputs, labels)
            
            valid_loss += loss.detach().cpu().item()
            valid_corr += correct
            
    validation_losses.append(valid_loss / len(test_loader))
    validation_accs.append(valid_corr / len(test_loader.dataset))                                         

Epoch 5 [loss:0.12, acc:96.83]: 100%|█████████████████████████████████| 5/5 [01:57<00:00, 23.56s/it]


In [33]:
# train alternate ResNet-18
import time
LEARNING_RATE = 0.01
BATCH_SIZE = 128
NUM_EPOCHS = 5

# Architecture
NUM_FEATURES = 28*28
NUM_CLASSES = 10

# Other
DEVICE = "cuda:0"
GRAYSCALE = True

def compute_accuracy(model, data_loader, device):
    correct_pred, num_examples = 0, 0
    for i, (features, targets) in enumerate(data_loader):
            
        features = features.to(device)
        targets = targets.to(device)

        logits, probas = model(features)
        _, predicted_labels = torch.max(probas, 1)
        num_examples += targets.size(0)
        correct_pred += (predicted_labels == targets).sum()
    return correct_pred.float()/num_examples * 100
    

start_time = time.time()
for epoch in range(NUM_EPOCHS):
    
    model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):
        
        features = features.to(DEVICE)
        targets = targets.to(DEVICE)
            
        ### FORWARD AND BACK PROP
        logits, probas = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()
        
        cost.backward()
        
        ### UPDATE MODEL PARAMETERS
        optimizer.step()
        
        ### LOGGING
        if not batch_idx % 50:
            print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' 
                   %(epoch+1, NUM_EPOCHS, batch_idx, 
                     len(train_loader), cost))

        

    model.eval()
    with torch.set_grad_enabled(False): # save memory during inference
        print('Epoch: %03d/%03d | Train: %.3f%%' % (
              epoch+1, NUM_EPOCHS, 
              compute_accuracy(model, train_loader, device=DEVICE)))
        
    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))
    
print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))

Epoch: 001/005 | Batch 0000/0118 | Cost: 2.3367
Epoch: 001/005 | Batch 0050/0118 | Cost: 0.2200
Epoch: 001/005 | Batch 0100/0118 | Cost: 0.1870
Epoch: 001/005 | Train: 96.558%
Time elapsed: 0.50 min
Epoch: 002/005 | Batch 0000/0118 | Cost: 0.1186
Epoch: 002/005 | Batch 0050/0118 | Cost: 0.1436
Epoch: 002/005 | Batch 0100/0118 | Cost: 0.0967
Epoch: 002/005 | Train: 97.928%
Time elapsed: 1.01 min
Epoch: 003/005 | Batch 0000/0118 | Cost: 0.1070
Epoch: 003/005 | Batch 0050/0118 | Cost: 0.0504
Epoch: 003/005 | Batch 0100/0118 | Cost: 0.0873
Epoch: 003/005 | Train: 98.432%
Time elapsed: 1.48 min
Epoch: 004/005 | Batch 0000/0118 | Cost: 0.0550
Epoch: 004/005 | Batch 0050/0118 | Cost: 0.0547
Epoch: 004/005 | Batch 0100/0118 | Cost: 0.0427
Epoch: 004/005 | Train: 99.228%
Time elapsed: 1.92 min
Epoch: 005/005 | Batch 0000/0118 | Cost: 0.0258
Epoch: 005/005 | Batch 0050/0118 | Cost: 0.0497
Epoch: 005/005 | Batch 0100/0118 | Cost: 0.0422
Epoch: 005/005 | Train: 99.533%
Time elapsed: 2.44 min
Total

In [34]:
training_accs

[0.6048833333333333,
 0.90955,
 0.9425333333333333,
 0.9578166666666666,
 0.9680833333333333]

In [35]:
validation_accs

[0.8809, 0.9307, 0.9473, 0.9555, 0.9611]

In [36]:
# Code imported form convrfm/cnfa_verification/pretrained_conv_nfa.py
import torch
import torch.nn as nn
import random
import numpy as np
from functorch import jacrev, vmap
from torch.nn.functional import pad
from numpy.linalg import eig
from copy import deepcopy
from torch.linalg import norm
from torchvision import models
import torchvision
from torchvision import transforms


SEED = 2323

torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.cuda.manual_seed(SEED)


def patchify(x, patch_size, stride_size, pad_type='zeros'):
    q1, q2 = patch_size
    s1, s2 = stride_size

    pad_1 = (q1-1)//2
    pad_2 = (q2-1)//2

    pad_dims = (pad_2, pad_2, pad_1, pad_1)
    if pad_type == 'zeros':
        x = pad(x, pad_dims)
    elif pad_type == 'circular':
        x = pad(x, pad_dims, 'circular')

    patches = x.unfold(2, q1, s1).unfold(3, q2, s2)
    patches = patches.transpose(1, 3).transpose(1, 2)
    return patches


class PatchConvLayer(nn.Module):

    def __init__(self, conv_layer):
        super().__init__()
        self.layer = conv_layer

    def forward(self, patches):
        out = torch.einsum('nwhcqr, kcqr -> nwhk', patches, self.layer.weight)
        n, w, h, k = out.shape
        out = out.transpose(1, 3).transpose(2, 3)
        return out


class PatchBasicBlock(nn.Module):

    def __init__(self, block_layer, downsample=False):
        super().__init__()
        self.layer = block_layer
        self.downsample = downsample

    # x is patches instead of images
    def forward(self, X):
        x, y = X
        ops = self.layer
        _, _, _, _, q, s = y.shape
        z = y[:, :, :, :, (q-1)//2, (s-1)//2]
        z = z.transpose(1, 3).transpose(2, 3)

        s1, s2 = ops.conv1.layer.stride
        x = x[:, ::s1, ::s2, :, :, :]
        o = ops.conv1(x).contiguous()
        o = ops.bn1(o)
        o = ops.relu(o)
        o = ops.conv2(o)
        o = ops.bn2(o)

        if self.downsample:
            z = ops.downsample(z)
        o += z
        o = ops.relu(o)
        return o


class PatchBottleneck(nn.Module):

    def __init__(self, block_layer, downsample=False):
        super().__init__()
        self.layer = block_layer
        self.downsample = downsample

    # x is patches instead of images
    def forward(self, X):
        x, y = X
        ops = self.layer
        _, _, _, _, q, s = x.shape

        z = y[:, :, :, :, (q-1)//2, (s-1)//2]
        z = z.transpose(1, 3).transpose(2, 3)

        s1, s2 = ops.conv1.layer.stride
        x = x[:, ::s1, ::s2, :, :, :]
        o = ops.conv1(x).contiguous()
        o = ops.bn1(o)
        o = ops.relu(o)

        o = ops.conv2(o)
        o = ops.bn2(o)
        o = ops.relu(o)

        o = ops.conv3(o)
        o = ops.bn3(o)

        if self.downsample:
            z = ops.downsample(z)
        o += z
        o = ops.relu(o)
        return o


def get_jacobian(net, data, c_idx=0, chunk=100):
    count = 0
    with torch.no_grad():
        def single_net(x):
            o = net(x)[:,c_idx*chunk:(c_idx+1)*chunk]
            return o
        return vmap(jacrev(single_net))(data)


def egop(model, X):
    ajop = 0
    c = 1000
    chunk_idxs = 100
    chunk = c // chunk_idxs
    for i in range(chunk_idxs):
        J = get_jacobian(model, X, c_idx=i, chunk=chunk)[0]
        J = J[0, 0].transpose(0, 1)
        # n is number of images
        # c is number of channels
        # w, h give number of total patches
        n, c, w, h, _, _, _ = J.shape
        J = J.transpose(1, 3).transpose(1, 2)

        grads = J.reshape(n*w*h, c, -1)
        ajop += torch.einsum('ncd, ncD -> dD', grads, grads)
    return ajop


def load_nn(net, init_net,
            block_idx=0):

    patchnet = deepcopy(net)
    l_idx = block_idx
    layer_idx = 0
    subnet = net[:l_idx]
    for m in subnet.children():
        if isinstance(m, nn.Conv2d):
            layer_idx += 1
        # For Resnet18, 34
        elif isinstance(m, torchvision.models.resnet.BasicBlock):
        # For ResNet50 -> 152
        #elif isinstance(m, torchvision.models.resnet.Bottleneck):
            modules = [mod for mod in m.modules() if not isinstance(mod, nn.Sequential)]
            for mod in modules:
                if isinstance(mod, nn.Conv2d):
                    layer_idx += 1

    patchnet = patchnet[l_idx:]
    if block_idx == 0:
        patchnet[0] = PatchConvLayer(patchnet[0])
    else:

        if patchnet[0].downsample is not None:
            downsample = True
        else:
            downsample = False

        # For Resnet18, 34
        patchnet[0].conv1 = PatchConvLayer(patchnet[0].conv1)
        patchnet[0] = PatchBasicBlock(patchnet[0], downsample=downsample)

        # For Resnet50 -> 152
        #patchnet[0].conv1 = PatchConvLayer(patchnet[0].conv1)
        #patchnet[0] = PatchBottleneck(patchnet[0], downsample=downsample)

    count = -1
    for idx, p in enumerate(net.parameters()):
        if len(p.shape) > 1:
            count += 1
        if count == layer_idx:
            M = p.data
            _, ki, q, s = M.shape

            M0 = [p for p in init_net.parameters()][idx].data

            M = M.reshape(-1, ki*q*s)
            M = torch.einsum('nd, nD -> dD', M, M)

            M0 = M0.reshape(-1, ki*q*s)
            M0 = torch.einsum('nd, nD -> dD', M0, M0)
            break

    return net, patchnet, M, M0, l_idx, (q, s)


def get_grads(net, patchnet, trainloader,
              kernel=(3,3),
              layer_idx=0):
    net.eval()
    net.cuda()
    patchnet.eval()
    patchnet.cuda()
    M = 0
    q,s = kernel

    # Num images for taking AGOP (set to >100 for deeper layers)
    MAX_NUM_IMGS = 10
    for idx, batch in enumerate(trainloader):
        print("Computing GOP for sample " + str(idx) + \
              " out of " + str(MAX_NUM_IMGS))
        imgs, _ = batch
        with torch.no_grad():
            imgs = imgs.cuda()
            imgs = net[:layer_idx](imgs).cpu()
        patches = patchify(imgs, (q, s), (1, 1))
        p_copy = deepcopy(patches)
        patches = patches.cuda()
        p_copy = p_copy.cuda()

        M += egop(patchnet, [patches.unsqueeze(0), p_copy.unsqueeze(0)]).cpu()
        del imgs, patches
        torch.cuda.empty_cache()
        if idx >= MAX_NUM_IMGS:
            break
    net.cpu()
    patchnet.cpu()
    return M


def min_max(M):
    return (M - M.min()) / (M.max() - M.min())


def correlation(M1, M2):
    M1 -= M1.mean()
    M2 -= M2.mean()

    norm1 = norm(M1.flatten())
    norm2 = norm(M2.flatten())
    M1 = M1.to(device)
    M2.to(device)
    
    return torch.sum(M1 * M2) / (norm1 * norm2)


def verify_NFA(net, init_net, trainloader, layer_idx=0):

    net, patchnet, M, M0, l_idx, (q, s) = load_nn(net,
                                                  init_net,
                                                  block_idx=layer_idx)
  
    i_val = correlation(M0, M)
    print("Correlation between Initial and Trained CNFM: ", i_val)

    G = get_grads(net, patchnet, trainloader, kernel=(q, s), layer_idx=l_idx)
    
    r_val = correlation(M, G)

    print("Correlation between Trained CNFM and AGOP: ", r_val)
    print("Final: ", i_val, r_val)

    return i_val.data.numpy(), r_val.data.numpy()


def subroutine_unroll_net(net):
    modules = list(net.children())
    unrolled = []
    for m in modules:
        if isinstance(m, nn.Sequential):
            unrolled += subroutine_unroll_net(m)
        else:
            unrolled.append(m)
    return unrolled


def unroll_net(net):
    modules = subroutine_unroll_net(net)[:-1]
    modules += [nn.Flatten(), list(net.children())[-1]]
    net = nn.Sequential(*modules)
    return net

    

In [37]:
# TODO: Fix error
idxs = list(range(4, 12))

fname = 'test.csv'
outf = open(fname, 'w')

net = model
init_net = resnet18(10)

net = unroll_net(net)
init_net = unroll_net(init_net)

# Set path to imagenet data

# Batch size should be 1 to avoid issues with grads for skip connections

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)) # specific for MNIST
])
train_dataset = torchvision.datasets.MNIST(
    root = config['data']['path'],
    train = True,
    transform = transform,
    download = False
)
train_loader = torch.utils.data.DataLoader(
    dataset = train_dataset,
    batch_size = config['data']['batch_size'],
    shuffle = True
)

for idx in idxs:
    i_val, r_val = verify_NFA(net, init_net, train_loader, layer_idx=idx)
    print("Layer " + str(idx+1) + ',' + str(i_val) + ',' + str(r_val), file=outf, flush=True)

Correlation between Initial and Trained CNFM:  tensor(0.0998, device='cuda:0')
Computing GOP for sample 0 out of 10


  warn_deprecated('jacrev')
  warn_deprecated('vmap', 'torch.vmap')


RuntimeError: Given input size: (512x1x1). Calculated output size: (512x-5x-5). Output size is too small

In [38]:
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
# print("Optimizer's state_dict:")
# for var_name in optimizer.state_dict():
#     print(var_name, "\t", optimizer.state_dict()[var_name])

Model's state_dict:
conv1.weight 	 torch.Size([64, 1, 7, 7])
bn1.weight 	 torch.Size([64])
bn1.bias 	 torch.Size([64])
bn1.running_mean 	 torch.Size([64])
bn1.running_var 	 torch.Size([64])
bn1.num_batches_tracked 	 torch.Size([])
layer1.0.conv1.weight 	 torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight 	 torch.Size([64])
layer1.0.bn1.bias 	 torch.Size([64])
layer1.0.bn1.running_mean 	 torch.Size([64])
layer1.0.bn1.running_var 	 torch.Size([64])
layer1.0.bn1.num_batches_tracked 	 torch.Size([])
layer1.0.conv2.weight 	 torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight 	 torch.Size([64])
layer1.0.bn2.bias 	 torch.Size([64])
layer1.0.bn2.running_mean 	 torch.Size([64])
layer1.0.bn2.running_var 	 torch.Size([64])
layer1.0.bn2.num_batches_tracked 	 torch.Size([])
layer1.1.conv1.weight 	 torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight 	 torch.Size([64])
layer1.1.bn1.bias 	 torch.Size([64])
layer1.1.bn1.running_mean 	 torch.Size([64])
layer1.1.bn1.running_var 	 torch.Size([64])
layer1.1.bn1.num_batc

In [41]:
p = load_nn(net,init_net,0)[1]

In [None]:
egop(p,)