In [1]:
from collections import OrderedDict
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import rand, matmul, diag, einsum
from time import time

In [2]:
import sys

# Add the parent directory to the module search path so that packages located
# there can be imported by later cells (presumably the `model` package).
# NOTE(review): '../' is resolved against the kernel's working directory, so
# this only works when the notebook is launched from its own folder - confirm.
sys.path.append('../')

In [3]:
from model.resnet import *

In [4]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split, ConcatDataset

In [5]:
# Pre-processing pipeline: convert each image to a tensor, then normalise all
# three channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# CIFAR-10 evaluation split (train=False), read from an existing local copy
# because download=False.
test_dataset = datasets.CIFAR10(
    root='../xor_neuron_data/data',
    train=False,
    transform=transform,
    download=False,
)

In [20]:
# --- Experiment configuration -------------------------------------------------
batch_size = 16       # samples per mini-batch
num_channel = 16      # channels of the synthetic conv feature map
arg_in_dim = 2        # number of values fed to the inner network per group
num_cell_types = 1
in_hidden_dim = 64    # hidden width of the MLP inner network

# --- Synthetic inputs ---------------------------------------------------------
# Draw from a dedicated, seeded generator so the tensors are identical on every
# "Restart & Run All" while the global torch RNG is left untouched.
input_rng = torch.Generator().manual_seed(0)

mlp_inputs = rand(batch_size, 128, generator=input_rng)                    # (batch, features)
conv_inputs = rand(batch_size, num_channel, 64, 64, generator=input_rng)   # (batch, C, H, W)

In [21]:
# Unshuffled loader over `test_dataset`, so batches come in dataset order.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Only the first mini-batch is needed here.
imgs, target = next(iter(test_loader))

In [22]:
# Three-layer MLP "inner network": arg_in_dim -> in_hidden_dim -> in_hidden_dim -> 1,
# with ReLU after each hidden layer. Layers are created in the same order as
# they are applied, and keep the names fc1/relu1/fc2/relu2/fc3.
inner_net = nn.Sequential(OrderedDict(
    fc1=nn.Linear(arg_in_dim, in_hidden_dim),
    relu1=nn.ReLU(),
    fc2=nn.Linear(in_hidden_dim, in_hidden_dim),
    relu2=nn.ReLU(),
    fc3=nn.Linear(in_hidden_dim, 1),
))

In [23]:
# Split the feature axis into groups of `arg_in_dim` values and map every group
# to one scalar. Only the trailing size-1 axis produced by the last linear layer
# is dropped: a bare `.squeeze()` would also remove the batch axis (or the group
# axis) whenever that axis happens to have size 1.
mlp_outputs = inner_net(mlp_inputs.reshape(batch_size, -1, arg_in_dim)).squeeze(-1)

In [24]:
# Shape check of the grouped view that is fed to the inner network.
conv_inputs.reshape(batch_size, num_channel // arg_in_dim, -1, arg_in_dim).size()

torch.Size([16, 8, 4096, 2])

In [25]:
# Group the feature map into chunks of `arg_in_dim` values, map each chunk to a
# scalar, and drop only the trailing size-1 axis of the inner network's output.
conv_outputs = inner_net(
    conv_inputs.reshape(batch_size, num_channel // arg_in_dim, -1, arg_in_dim)
).squeeze(-1)

# Restore the spatial layout using the input's own (H, W) rather than a
# hard-coded 64x64, so the cell keeps working if `conv_inputs` is resized.
conv_outputs = conv_outputs.reshape(batch_size, -1, *conv_inputs.shape[-2:])

In [26]:
# Print input/output shapes for the MLP path, then for the conv path.
for tensor in (mlp_inputs, mlp_outputs, conv_inputs, conv_outputs):
    print(tensor.shape)

torch.Size([16, 128])
torch.Size([16, 64])
torch.Size([16, 16, 64, 64])
torch.Size([16, 8, 64, 64])


In [110]:
class QuadraticInnerNet(nn.Module):
    """Quadratic inner network applied to small groups of input values.

    Every group ``z`` of ``arg_in_dim`` consecutive values is mapped to the
    scalar ``z^T W_A z + w_b . z``, where ``W_A`` and ``w_b`` are the weights of
    the bias-free linear layers ``A`` and ``b``.

    Args:
        arg_in_dim: number of values per group (default 2, the original
            hard-coded size).

    Shapes:
        * 2-D input ``(batch, features)``  -> ``(batch, features // arg_in_dim)``
        * 4-D input ``(batch, C, H, W)``   -> ``(batch, C // arg_in_dim, H, W)``

    Raises:
        ValueError: if the input is neither 2-D nor 4-D.
    """

    def __init__(self, arg_in_dim=2):
        super(QuadraticInnerNet, self).__init__()
        self.arg_in_dim = arg_in_dim
        self.A = nn.Linear(arg_in_dim, arg_in_dim, bias=False)  # quadratic term
        self.b = nn.Linear(arg_in_dim, 1, bias=False)           # linear term

    def forward(self, x):
        """Apply the quadratic form group-wise; see the class docstring for shapes."""
        group = self.arg_in_dim

        if x.dim() == 2:
            inputs = x.reshape(x.shape[0], -1, group)
            # Keep the batch axis even when it (or the group axis) has size 1;
            # a bare ``squeeze()`` would silently drop it.
            out_shape = inputs.shape[:-1]
        elif x.dim() == 4:
            # NOTE(review): this reshape groups values that are consecutive in
            # row-major order of the (C, H, W) block, not values taken from
            # `group` different channels at one pixel - confirm that is intended.
            inputs = x.reshape(x.shape[0], x.shape[1] // group, -1, group)
            # Use the real height and width so non-square maps are supported.
            out_shape = (x.shape[0], -1, x.shape[-2], x.shape[-1])
        else:
            raise ValueError(
                "QuadraticInnerNet expects a 2-D or 4-D input, got a tensor "
                "with {} dimension(s)".format(x.dim())
            )

        # z^T W_A z: contract A(z) with z over the trailing group axis.
        quadratic = einsum('...k,...k->...', self.A(inputs), inputs)
        # w_b . z: the linear layer leaves a trailing size-1 axis; drop only that.
        linear = self.b(inputs).squeeze(-1)

        return (quadratic + linear).reshape(out_shape)

In [111]:
# Instantiate the quadratic inner network with freshly initialised A and b layers.
quad_inner = QuadraticInnerNet()

In [112]:
# Echo the module so its registered sub-layers appear in the cell output.
quad_inner

QuadraticInnerNet(
  (A): Linear(in_features=2, out_features=2, bias=False)
  (b): Linear(in_features=2, out_features=1, bias=False)
)

In [113]:
# Shape of the 4-D input that is passed to the quadratic inner network.
conv_inputs.size()

torch.Size([16, 16, 64, 64])

In [114]:
# Shape of the 2-D input that is passed to the quadratic inner network.
mlp_inputs.size()

torch.Size([16, 128])

In [115]:
# Grouped view of the conv input: (batch, num_channel // arg_in_dim, -1, arg_in_dim).
# NOTE(review): no later cell visible here reads `input_innernet` - confirm it is still needed.
input_innernet = torch.reshape(
    conv_inputs,
    (batch_size, num_channel // arg_in_dim, -1, arg_in_dim),
)

In [116]:
# Run the 4-D branch of QuadraticInnerNet.forward on the conv-style input.
# This rebinds `conv_outputs`, replacing the value computed with `inner_net`.
conv_outputs = quad_inner(conv_inputs)

In [117]:
# Shape of the quadratic inner network's output for the 4-D input.
conv_outputs.size()

torch.Size([16, 8, 64, 64])

In [118]:
# Shape check of the grouped view the 2-D branch works on.
mlp_inputs.reshape(batch_size, -1, arg_in_dim).size()

torch.Size([16, 64, 2])

In [119]:
# Run the 2-D branch of QuadraticInnerNet.forward on the flat input.
# This rebinds `mlp_outputs`, replacing the value computed with `inner_net`.
mlp_outputs = quad_inner(mlp_inputs)

In [120]:
# Shape of the quadratic inner network's output for the 2-D input.
mlp_outputs.size()

torch.Size([16, 64])