In [1]:
import sys

# Make the parent directory importable so the local `model` package used below
# (`model.resnet_xor`, `model.resnet`) resolves. Relative to the current working
# directory, so this assumes the notebook is run from its own folder — TODO confirm.
sys.path.append('../')

In [2]:
import pickle
from glob import glob
import os
import yaml
from easydict import EasyDict as edict
import numpy as np
import torch

In [3]:
from model.resnet_xor import *
from model.resnet import *

In [4]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split, ConcatDataset

In [5]:
# Preprocessing for the CIFAR-10 test split: ToTensor scales pixels to [0, 1],
# then Normalize applies (x - 0.5) / 0.5 per channel, mapping them to [-1, 1].
CHANNEL_MEAN = [0.5, 0.5, 0.5]
CHANNEL_STD = [0.5, 0.5, 0.5]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CHANNEL_MEAN, CHANNEL_STD),
])

# Held-out split only (train=False); fetched into the data root if not present yet.
test_dataset = datasets.CIFAR10(
    root='../xor_neuron_data/data',
    train=False,
    transform=transform,
    download=True,
)

Files already downloaded and verified


In [6]:
# Fixed-order batches of 16 test images (no shuffling), so the first batch is
# the same on every run.
test_loader = DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=False,
)

In [7]:
# Take just the first batch of images and labels for a forward-pass smoke test.
batch_iter = iter(test_loader)
imgs, target = next(batch_iter)

In [8]:
def load_yaml_config(path_pattern):
    """Load the first YAML file matching ``path_pattern`` as an EasyDict.

    Args:
        path_pattern: a file path or glob pattern, resolved relative to the
            notebook's working directory.

    Returns:
        (path, config): the resolved file path and its parsed contents.

    Raises:
        FileNotFoundError: if nothing matches (instead of the bare IndexError
            that ``glob(...)[0]`` raises).
    """
    # sorted(): glob's ordering is filesystem-dependent; make the pick deterministic.
    matches = sorted(glob(path_pattern))
    if not matches:
        raise FileNotFoundError(f'No config file matches {path_pattern!r}')
    path = matches[0]
    # `with` guarantees the file handle is closed once parsing is done.
    # NOTE(review): yaml.safe_load would suffice for plain configs; FullLoader is
    # kept so parsing behaves exactly as before.
    with open(path, 'r') as f:
        return path, edict(yaml.load(f, Loader=yaml.FullLoader))


config_file, config = load_yaml_config('../config/resnet/0207_resnet_quad_v2_rtx30/4.yaml')
control_config_file, control_config = load_yaml_config('../config/control_model/resnet20.yaml')

In [9]:
# Seed torch's global RNG before building the network so its randomly
# initialised Conv2d/Linear weights — and therefore the logits and loss shown
# below (and the control model built later) — are identical on every
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
model = ResNet20_Xor(config)

In [10]:
# Inspect the XOR-neuron ResNet's module tree.
# NOTE(review): each conv emits twice the channels the next layer consumes
# (e.g. stem conv1 3->32 feeds 16-channel layer1; block conv1 32 out, conv2 16 in).
# Presumably `inner_net` (Linear in_features=2 -> 1) merges channel pairs — confirm
# in model/resnet_xor.py.
model

ResNet_Xor(
  (loss_func): CrossEntropyLoss()
  (inner_net): QuadraticInnerNet(
    (A): Linear(in_features=2, out_features=2, bias=False)
    (b): Linear(in_features=2, out_features=1, bias=True)
  )
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock_InnerNet(
      (inner_net): QuadraticInnerNet(
        (A): Linear(in_features=2, out_features=2, bias=False)
        (b): Linear(in_features=2, out_features=1, bias=True)
      )
      (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequent

In [11]:
# Forward one test batch through the freshly built (untrained) XOR-ResNet.
# The model returns a 3-tuple; the third element is not used here.
# NOTE(review): model.eval() / torch.no_grad() are not used, so the model runs in
# its default train mode (BatchNorm uses batch stats and updates running stats)
# and autograd records a graph (see grad_fn on `loss`).
out, loss, _ = model(imgs, target)

In [12]:
# Raw per-class scores for the batch: one row of 10 values per image.
out

tensor([[ 1.4446e-01, -4.6949e-01,  6.4128e-01, -8.9683e-01, -1.6483e-01,
          2.6049e-01,  1.8021e-02,  3.8705e-01, -1.9203e-01,  2.6809e-01],
        [ 2.4389e-02, -7.7869e-01,  6.8173e-01, -7.7010e-01, -1.5189e-01,
         -1.1013e-01, -7.1625e-02,  7.2703e-01, -4.4885e-01,  3.2069e-01],
        [ 5.6768e-02, -6.3377e-01,  5.8100e-01, -8.1527e-01, -1.0849e-01,
          1.0627e-01, -2.1905e-02,  5.7978e-01, -3.5459e-01,  2.6987e-01],
        [ 1.2265e-01, -6.4515e-01,  7.7054e-01, -9.5554e-01, -2.0009e-01,
          5.9910e-02, -3.8645e-02,  5.9949e-01, -3.1011e-01,  2.5074e-01],
        [ 1.4394e-01, -6.0458e-01,  4.4935e-01, -8.6440e-01, -1.9525e-01,
          1.6387e-01,  3.6469e-02,  4.5809e-01, -1.4029e-01,  2.3125e-01],
        [ 8.4479e-02, -6.0649e-01,  6.8864e-01, -9.1188e-01, -2.3811e-01,
          2.2694e-01,  3.8692e-02,  4.7409e-01, -1.8935e-01,  3.2752e-01],
        [-3.4333e-02, -8.0641e-01,  4.1250e-01, -6.9732e-01, -2.5704e-01,
          1.3402e-01,  3.6697e-0

In [13]:
# Scalar batch loss; presumably computed by the model's `loss_func` (CrossEntropyLoss
# in the repr) — confirm. grad_fn shows it is still attached to the autograd graph.
loss

tensor(3.0155, grad_fn=<NllLossBackward>)

In [14]:
# Control baseline: a standard ResNet-20 built from the control config, for
# comparison with the XOR-neuron model above.
origin_resnet = resnet20(control_config)

In [15]:
# Control model's module tree: 16-channel conv layers with a plain ReLU
# (`activation_fnc`) instead of the XOR model's `inner_net`.
origin_resnet

ResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation_fnc): ReLU()
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, e