<a href="https://colab.research.google.com/github/AlexCuozzo/SpatialFeatureLayer/blob/master/ConvolutionWithoutWeightTransport.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Convolutional Layer Without Weight Transport

## What is weight transport?
Weight transport is a common reason for rejecting a type of neural network model as biologically implausible. In its original sense, the weight transport problem refers to backpropagation needing the backward pass to reuse the transposes of the forward weights. Here the term is used more broadly: any time weights are reused or accessed across multiple layers, used multiple times within a layer, or generally used in anything other than a local operation. The grounds for claiming biological implausibility are that biological neural networks don't have numbers floating around that can be used in multiple places: everything is a physical connection that can only be modified by local information.


## A step in the right direction

It has been shown that the human visual cortex maintains a retinotopic map, meaning that visual inputs keep their spatial relationships as they are processed. The convolutional layer captures this behavior, because the feature maps it constructs are spatially aligned with its input. In addition, experiments such as those of Hubel and Wiesel have shown that the brain contains feature-selective neurons, and that the features they respond to become more complex further up the cortical hierarchy. Again, the convolutional layer captures this behavior through its filters. Not too bad. However, convolutional layers still rely on weight transport: a convolution slides one filter across the whole input, so the same weights are reused at every spatial position. That sharing is inherently nonlocal, and the layer proposed below is a simple way of removing it.
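To make the weight-sharing point concrete, here is a small sketch (not part of the original notebook) showing that `F.conv2d` is equivalent to extracting every 3×3 patch with `F.unfold` and multiplying all of the patches by one and the same filter matrix. The tensor sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 3, 8, 8)          # (N, C_in, H, W), arbitrary example sizes
weight = torch.randn(4, 3, 3, 3)     # (C_out, C_in, kh, kw): a single filter bank

# Reference result from the built-in convolution (no padding, stride 1)
ref = F.conv2d(x, weight)            # (1, 4, 6, 6)

# Every 3x3 patch as a column: (N, C_in*kh*kw, L) with L = 6*6 = 36 positions
patches = F.unfold(x, kernel_size=3)

# The SAME weight matrix multiplies all 36 patches: this is the weight sharing
w_mat = weight.view(4, -1)           # (C_out, C_in*kh*kw)
out = (w_mat @ patches).view(1, 4, 6, 6)

print(torch.allclose(out, ref, atol=1e-5))  # True
```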



## The proposal

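Like a convolutional layer, this layer consists of many "filters", but they don't move. Each output position has its own filter, wired up statically and updated just like the weights of a dense network, and all nodes in a vertical stack (the output channels at one position) map to the same local volume in the level below. In machine-learning terms this is essentially a locally connected layer, i.e. a convolution without weight sharing.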

I'll call it a SpatialFeature layer.
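A minimal sketch of the same idea in compact form, assuming no padding and stride 1 as in the layer below: the weights get one extra axis indexed by output position, so each patch is multiplied by its own filter instead of a shared one. The helper name `locally_connected` and its weight layout are hypothetical; the notebook's `SpatialFeatureLayer` stores a larger dense tensor instead.

```python
import torch
import torch.nn.functional as F

def locally_connected(x, weight, bias, kernel_size):
    """Hypothetical helper: an unshared 'convolution', one filter per output position.

    x:      (N, C_in, H, W)
    weight: (L, C_out, C_in*kh*kw) with L = H_out*W_out, positions in row-major order
    bias:   (C_out, H_out, W_out)
    """
    kh, kw = kernel_size
    h_out, w_out = x.size(2) - kh + 1, x.size(3) - kw + 1
    patches = F.unfold(x, kernel_size)                   # (N, C_in*kh*kw, L)
    # For each position l, apply that position's own matrix weight[l]
    out = torch.einsum('nkl,lok->nol', patches, weight)  # (N, C_out, L)
    return out.reshape(x.size(0), -1, h_out, w_out) + bias

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)
num_positions, c_out = 6 * 6, 4
weight = torch.randn(num_positions, c_out, 3 * 3 * 3)
bias = torch.zeros(c_out, 6, 6)
y = locally_connected(x, weight, bias, (3, 3))
print(y.shape)  # torch.Size([2, 4, 6, 6])

# Changing the filter of position 0 changes only output position (0, 0)
weight2 = weight.clone()
weight2[0] += 1.0
y2 = locally_connected(x, weight2, bias, (3, 3))
changed = (y2 - y).abs().amax(dim=(0, 1)) > 1e-6
print(changed.nonzero().tolist())  # [[0, 0]]
```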

In [None]:
import torch
import torch.nn as nn
import math
import warnings
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.nn import functional as F
from torch.nn import init

In [None]:
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
from typing import Optional, List, Tuple, Union, T

In [None]:
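class SpatialFeatureLayer(nn.Module):
    """Convolution-like layer whose filters do not move (no weight sharing).

    Each output position (oh, ow) has its own filter, connected to the
    kernel_size window of the input whose top-left corner is (oh, ow).
    """

    def __init__(self,
                 in_dimensions: Tuple[int, int, int],
                 out_channels: int,
                 kernel_size: Tuple[int, int],
                 bias: bool = True):
        super().__init__()
        self.in_channels, self.in_height, self.in_width = in_dimensions
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        # No padding, stride 1: same output size as an unpadded Conv2d
        self.output_height = self.in_height - kernel_size[0] + 1
        self.output_width = self.in_width - kernel_size[1] + 1
        # Dense weight tensor: (C_in, H_in, W_in, C_out, H_out, W_out)
        self.weight = Parameter(torch.empty(self.in_channels, self.in_height, self.in_width,
                                            self.out_channels, self.output_height, self.output_width))
        if bias:
            self.bias = Parameter(torch.empty(self.out_channels, self.output_height, self.output_width))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.bias is not None:
            init.kaiming_uniform_(self.bias, a=math.sqrt(5))
        # Start from all zeros, then initialise only the weights inside each
        # output unit's receptive field. forward() applies no mask, so training
        # can still make the out-of-window weights nonzero.
        init.zeros_(self.weight)
        kh, kw = self.kernel_size
        for ow in range(self.output_width):
            for oh in range(self.output_height):
                for oc in range(self.out_channels):
                    init.kaiming_uniform_(self.weight[:, oh:oh + kh, ow:ow + kw, oc, oh, ow])

    def forward(self, x: Tensor) -> Tensor:
        # (N, C_in, H_in, W_in) x (C_in, H_in, W_in, C_out, H_out, W_out) -> (N, C_out, H_out, W_out)
        x = torch.tensordot(x, self.weight, dims=([1, 2, 3], [0, 1, 2]))
        if self.bias is not None:
            x = x + self.bias
        return x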
        
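`reset_parameters` zeroes the weights outside each receptive field only at initialization; nothing in `forward` keeps them at zero afterwards. The sketch below, with small made-up sizes and a hypothetical 0/1 `mask` tensor, checks this: with a plain `tensordot`, the loss gradient reaches weights outside the 3×3 windows, so the optimizer can gradually turn the layer into a fully connected one. Multiplying the weight by a fixed mask inside the forward pass is one way to block those gradients.

```python
import torch

torch.manual_seed(0)
C_in, H, W, C_out, kh, kw = 2, 5, 5, 3, 3, 3
H_out, W_out = H - kh + 1, W - kw + 1

# Hypothetical 0/1 mask: 1 where input pixel (h, w) lies in the window of output (oh, ow)
mask = torch.zeros(1, H, W, 1, H_out, W_out)
for oh in range(H_out):
    for ow in range(W_out):
        mask[0, oh:oh + kh, ow:ow + kw, 0, oh, ow] = 1.0

# Dense weight that is zero outside the windows, like after reset_parameters
weight = (torch.randn(C_in, H, W, C_out, H_out, W_out) * mask).requires_grad_()
x = torch.randn(4, C_in, H, W)
outside = (mask == 0).expand_as(weight)

# Unmasked forward pass: gradients leak into the out-of-window weights
y = torch.tensordot(x, weight, dims=([1, 2, 3], [0, 1, 2]))
y.pow(2).sum().backward()
leak = weight.grad[outside].abs().max().item()
print(leak > 0)  # True

# Masked forward pass: out-of-window weights receive exactly zero gradient
weight.grad = None
y = torch.tensordot(x, weight * mask, dims=([1, 2, 3], [0, 1, 2]))
y.pow(2).sum().backward()
blocked = weight.grad[outside].abs().max().item()
print(blocked == 0.0)  # True
```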

# Training a model with the layer

In [None]:
import os

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

In [None]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Normalize the test set same as training set without augmentation
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

In [None]:
data_path = "./data"
trainset = torchvision.datasets.CIFAR10(
    root=data_path, train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=64, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root=data_path, train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=64, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
class RegularConvNet(nn.Module):
    def __init__(self):
        super(RegularConvNet, self).__init__()
        self.conv_layer = nn.Sequential(
            # Conv Layer block 1
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Conv Layer block 2
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout2d(p=0.05),
            # Conv Layer block 3
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc_layer = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(4096, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Linear(512, 10)
        )
    def forward(self, x):
        """Perform forward."""
        x = self.conv_layer(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layer(x)
        return x

In [None]:
class SpatialDenseNet(nn.Module):
    def __init__(self):
        super(SpatialDenseNet, self).__init__()
        self.conv_layer = nn.Sequential(
            SpatialFeatureLayer((3, 32, 32), 10, kernel_size=(3, 3)),
            nn.BatchNorm2d(10),
            nn.ReLU(inplace=True),
            SpatialFeatureLayer((10, 30, 30), 20, kernel_size=(3, 3)),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Conv Layer block 2
            SpatialFeatureLayer((20, 14, 14), 32, kernel_size=(3, 3)),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            SpatialFeatureLayer((32, 12, 12), 64, kernel_size=(3, 3)),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout2d(p=0.05),
            # Conv Layer block 3
            SpatialFeatureLayer((64, 5, 5), 128, kernel_size=(2, 2)),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            SpatialFeatureLayer((128, 4, 4), 256, kernel_size=(2, 2)),
            nn.ReLU(inplace=True),
        )
        self.fc_layer = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(2304, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Linear(512, 10)
        )
    def forward(self, x):
        """Perform forward."""
        x = self.conv_layer(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layer(x)
        return x
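The input size of the first `nn.Linear` (2304 in `SpatialDenseNet`, 4096 in `RegularConvNet`) follows from the spatial size after each layer. A small arithmetic sketch, assuming unpadded stride-1 windows for each `SpatialFeatureLayer` (as in its constructor), `padding=1` for each 3×3 `Conv2d`, and 2×2 max-pooling with stride 2 and floor rounding:

```python
def valid_size(n, k):
    """Output length of an unpadded, stride-1 window of size k over length n."""
    return n - k + 1

def pool_size(n, k=2, s=2):
    """Output length of max-pooling with kernel k and stride s (floor rounding)."""
    return (n - k) // s + 1

# SpatialDenseNet: track the side length of the square feature maps
side = 32
side = valid_size(side, 3)   # SpatialFeatureLayer -> 10 x 30 x 30
side = valid_size(side, 3)   # SpatialFeatureLayer -> 20 x 28 x 28
side = pool_size(side)       # MaxPool2d           -> 20 x 14 x 14
side = valid_size(side, 3)   # SpatialFeatureLayer -> 32 x 12 x 12
side = valid_size(side, 3)   # SpatialFeatureLayer -> 64 x 10 x 10
side = pool_size(side)       # MaxPool2d           -> 64 x 5 x 5
side = valid_size(side, 2)   # SpatialFeatureLayer -> 128 x 4 x 4
side = valid_size(side, 2)   # SpatialFeatureLayer -> 256 x 3 x 3
spatial_flat = 256 * side * side
print(spatial_flat)  # 2304

# RegularConvNet: padding=1 keeps the size, so only the three pools shrink it
side = 32
for _ in range(3):
    side = pool_size(side)   # 32 -> 16 -> 8 -> 4
regular_flat = 256 * side * side
print(regular_flat)  # 4096
```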

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
spatial_dense = SpatialDenseNet().to(device)



In [None]:
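inputs, labels = next(iter(trainloader))  # one batch of 64 training images
# Size of the flattened convolutional features that feed the first nn.Linear
spatial_dense.conv_layer(inputs.to(device)).flatten(1).size()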

torch.Size([64, 2304])

In [None]:
criterion = nn.CrossEntropyLoss()
spatial_dense = spatial_dense.to(device)
optimizer = torch.optim.Adam(spatial_dense.parameters(), lr=0.0003)
num_epochs = 5
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader, 0):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = spatial_dense(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.data.item()
    running_loss /= len(trainloader)
    print("Epoch: {0} | Loss: {1}".format(epoch+1, running_loss))
if not os.path.isdir('./checkpoint'):
    os.mkdir('checkpoint')
torch.save(spatial_dense.state_dict(), "./checkpoint/spatialDense")

Epoch: 1 | Loss: 1.8816450436401855
Epoch: 2 | Loss: 1.6222480977587688
Epoch: 3 | Loss: 1.539194023975021
Epoch: 4 | Loss: 1.4816871767153825
Epoch: 5 | Loss: 1.4341203070357633


In [None]:
regular_conv = RegularConvNet()

In [None]:
criterion = nn.CrossEntropyLoss()
regular_conv = regular_conv.to(device)
optimizer = torch.optim.Adam(regular_conv.parameters(), lr=0.0003)
num_epochs = 5
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader, 0):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = regular_conv(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    running_loss /= len(trainloader)
    print("Epoch: {0} | Loss: {1}".format(epoch+1, running_loss))
if not os.path.isdir('./checkpoint'):
    os.mkdir('checkpoint')
torch.save(regular_conv.state_dict(), "./checkpoint/regularConv")


In [None]:
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.sl = SpatialFeatureLayer((3, 32, 32), 5, (3, 3))
        self.fc = nn.Linear(5*30*30, 10)
    def forward(self, x):
        x = self.sl(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

In [None]:
criterion = nn.CrossEntropyLoss()
model = SimpleModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003)
num_epochs = 5
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader, 0):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.data.item()
    running_loss /= len(trainloader)
    print("Epoch: {0} | Loss: {1}".format(epoch+1, running_loss))



Epoch: 1 | Loss: 2.694465982639576
Epoch: 2 | Loss: 2.0974156254392757
Epoch: 3 | Loss: 2.073304990673309
Epoch: 4 | Loss: 2.076695164451209
Epoch: 5 | Loss: 2.074775897328506


In [None]:
def test_model(model, test_loader):
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # test model
    model.to(device)
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predictions = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predictions == labels).sum().item()
        print("Accuracy: {} %".format(100 * correct / total))

In [None]:
test_model(regular_conv, testloader)

Accuracy: 86.9 %


In [None]:
test_model(spatial_dense, testloader)

Accuracy: 52.73 %


# Takeaways


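Removing the weight sharing makes this layer more biologically plausible than a convolution, but it is not very useful. It definitely needs to be optimized for memory: it currently stores a dense weight tensor connecting every input pixel to every output position, even though most of those weights start at zero. As a machine learning method it works, but not as well as the regular `Conv2d` layer: after 5 epochs, `SpatialDenseNet` reached 52.73 % test accuracy versus 86.9 % for `RegularConvNet`. The likely reason is that there are many more weights to tune, so training takes longer. From a practical standpoint, this layer is probably not recommended.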
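To put a number on "many more weights", here is a short count (plain arithmetic, not a measurement) for the six `SpatialFeatureLayer`s in `SpatialDenseNet`: the entries of the dense tensor each layer stores, the entries that fall inside a receptive field, and the weights of a `Conv2d` with the same channels and kernel. Biases are left out.

```python
# (in_channels, in_side, out_channels, kernel) for each SpatialFeatureLayer in SpatialDenseNet
layers = [(3, 32, 10, 3), (10, 30, 20, 3), (20, 14, 32, 3),
          (32, 12, 64, 3), (64, 5, 128, 2), (128, 4, 256, 2)]

dense = local = shared = 0
for c_in, side, c_out, k in layers:
    out_side = side - k + 1
    dense += c_in * side * side * c_out * out_side * out_side  # stored 6-D weight tensor
    local += out_side * out_side * c_out * c_in * k * k         # entries inside a window
    shared += c_out * c_in * k * k                              # Conv2d with weight sharing

print(f"{dense:,} stored, {local:,} in-window, {shared:,} with weight sharing")
# 224,317,952 stored, 6,030,776 in-window, 190,102 with weight sharing
```

At 4 bytes per float32 value, the stored tensors alone take roughly 0.9 GB, while a truly local layout would need about 1/37 of that.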