In [1]:
import sys

sys.path.append('../')

In [2]:
import pickle
from glob import glob
import os
import yaml
from easydict import EasyDict as edict
import numpy as np
import torch

In [3]:
from model.resnet_xor import *
from model.resnet import *

In [4]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split, ConcatDataset

In [5]:
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
test_dataset = datasets.CIFAR10(root='../xor_neuron_data/data',
                                train=False,
                                transform=transform,
                                download=True)

Files already downloaded and verified


In [6]:
test_loader = DataLoader(dataset=test_dataset,
                                 batch_size=16,
                                 shuffle=False)

In [7]:
imgs, target = next(iter(test_loader))

In [8]:
# config_file = glob('../config/resnet/0207_resnet_quad_v2_rtx30/4.yaml')[0]
# config = edict(yaml.load(open(config_file, 'r'), Loader=yaml.FullLoader))

config_file = glob('../config/resnet/xor_resnet.yaml')[0]
config = edict(yaml.load(open(config_file, 'r'), Loader=yaml.FullLoader))

control_config_file = glob('../config/control_model/resnet20.yaml')[0]
control_config = edict(yaml.load(open(control_config_file, 'r'), Loader=yaml.FullLoader))

In [9]:
config.model.num_blocks = [2,2,2,2]

In [10]:
config.model.name = 'resnet18'
config.model.inner_net = 'quad'

In [11]:
model = ResNet20_Xor(config)

In [12]:
model

ResNet_Xor(
  (loss_func): CrossEntropyLoss()
  (inner_net): QuadraticInnerNet(
    (A): Linear(in_features=2, out_features=2, bias=False)
    (b): Linear(in_features=2, out_features=1, bias=True)
    (relu): ReLU(inplace=True)
    (batch_norm): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock_InnerNet(
      (inner_net): QuadraticInnerNet(
        (A): Linear(in_features=2, out_features=2, bias=False)
        (b): Linear(in_features=2, out_features=1, bias=True)
        (relu): ReLU(inplace=True)
        (batch_norm): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, 

In [13]:
out, loss, _ = model(imgs, target)

torch.Size([16, 16, 1024, 2])
torch.Size([16, 16, 1024])
torch.Size([16, 16, 1024, 2])
torch.Size([16, 16, 1024])
torch.Size([16, 16, 1024, 2])
torch.Size([16, 16, 1024])
torch.Size([16, 16, 1024, 2])
torch.Size([16, 16, 1024])
torch.Size([16, 16, 1024, 2])
torch.Size([16, 16, 1024])
torch.Size([16, 16, 1024, 2])
torch.Size([16, 16, 1024])
torch.Size([16, 16, 1024, 2])
torch.Size([16, 16, 1024])
torch.Size([16, 32, 256, 2])
torch.Size([16, 32, 256])
torch.Size([16, 32, 256, 2])
torch.Size([16, 32, 256])
torch.Size([16, 32, 256, 2])
torch.Size([16, 32, 256])
torch.Size([16, 32, 256, 2])
torch.Size([16, 32, 256])
torch.Size([16, 32, 256, 2])
torch.Size([16, 32, 256])
torch.Size([16, 32, 256, 2])
torch.Size([16, 32, 256])
torch.Size([16, 64, 64, 2])
torch.Size([16, 64, 64])
torch.Size([16, 64, 64, 2])
torch.Size([16, 64, 64])
torch.Size([16, 64, 64, 2])
torch.Size([16, 64, 64])
torch.Size([16, 64, 64, 2])
torch.Size([16, 64, 64])
torch.Size([16, 64, 64, 2])
torch.Size([16, 64, 64])
torch.

In [14]:
out

tensor([[ 8.2685e-02, -4.7324e-02,  1.8448e-02, -1.9922e-01,  7.4221e-02,
         -3.9740e-02, -7.9478e-02, -2.6982e-02, -2.8880e-02,  7.9810e-02],
        [ 8.2685e-02, -4.7324e-02,  1.8448e-02, -1.9922e-01,  7.4221e-02,
         -3.9740e-02, -7.9478e-02, -2.6982e-02, -2.8880e-02,  7.9810e-02],
        [ 8.2685e-02, -4.7324e-02,  1.8448e-02, -1.9922e-01,  7.4221e-02,
         -3.9740e-02, -7.9478e-02, -2.6982e-02, -2.8880e-02,  7.9810e-02],
        [ 8.2685e-02, -4.7324e-02,  1.8448e-02, -1.9922e-01,  7.4221e-02,
         -3.9740e-02, -7.9478e-02, -2.6982e-02, -2.8880e-02,  7.9810e-02],
        [ 8.2685e-02, -4.7324e-02,  1.8448e-02, -1.9922e-01,  7.4221e-02,
         -3.9740e-02, -7.9478e-02, -2.6982e-02, -2.8880e-02,  7.9810e-02],
        [ 8.2685e-02, -4.7324e-02,  1.8448e-02, -1.9922e-01,  7.4221e-02,
         -3.9740e-02, -7.9478e-02, -2.6982e-02, -2.8880e-02,  7.9810e-02],
        [-4.8800e+00,  1.6423e+00, -9.4857e+00,  1.8785e+01,  9.5898e+00,
         -4.7844e+00, -8.6375e+0

In [15]:
loss

tensor(3.2496, grad_fn=<NllLossBackward>)