<a href="https://colab.research.google.com/github/IanWangg/Brain-inspired-scale-invariant-CNN/blob/multi-column-network/Cone_and_rod.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [13]:
import torch.nn as nn
import torch.nn.functional as F
import torch

!pip install torchviz



###Download the Training Data

In [14]:
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import random

batch_size=64

train_transform = transforms.Compose([
    transforms.ToTensor()
])

train_data = dsets.CIFAR10(root = './data', train = True,
                        transform = train_transform, download = True)

train_gen = torch.utils.data.DataLoader(dataset = train_data,
                                             batch_size = batch_size,
                                             shuffle = True)


Files already downloaded and verified


In [15]:
type(train_data[0][0])
train_data[0][0].shape

torch.Size([3, 32, 32])

###Training and Testing Method

In [16]:
def train(net, lr=0.001, num_epochs=10, batch_size=64):  
  if torch.cuda.is_available():
    net.cuda()
  loss_function = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam( net.parameters(), lr=lr)

  for epoch in range(num_epochs):
    for i ,(images,labels) in enumerate(train_gen):
      if torch.cuda.is_available():
        images = images.cuda()
        labels = labels.cuda()
      
      optimizer.zero_grad()
      outputs = net(images)
      loss = loss_function(outputs, labels)
      loss.backward()
      optimizer.step()
      
      if (i+1) % 100 == 0:
        print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f'
                  %(epoch+1, num_epochs, i+1, len(train_data)//batch_size, loss.item()))
def test(net):  
  if(net.multiScale):
    print('RESULTS OF MULTISCALE CNN')
  else:
    print('RESULTS OF STANDARD CNN')
  if torch.cuda.is_available():
    net.cuda()
  correct = 0
  total = 0
  # loss_function = nn.CrossEntropyLoss()
  for images,labels in train_gen:
    if torch.cuda.is_available():
      images = images.cuda()
      labels = labels.cuda()
    
    output = net(images)
    # loss = loss_function(outputs, labels)
    _, predicted = torch.max(output,1)
    correct += (predicted == labels).sum()
    total += labels.size(0)
  train_acc = (100*correct.cpu().numpy())/(total+1)
  print('Train accuracy of the model: %.3f %%' %(train_acc))
  print(correct, total)

  correct = 0
  total = 0
  for images,labels in test_gen:
    if torch.cuda.is_available():
      images = images.cuda()
      labels = labels.cuda()
    
    output = net(images)
    # loss = loss_function(outputs, labels)
    _, predicted = torch.max(output,1)
    correct += (predicted == labels).sum()
    total += labels.size(0)
  test_acc = (100*correct.cpu().numpy())/(total+1)
  print('Test accuracy of the model: %.3f %%' %(test_acc))
  print(correct, total)

###Create the Gaussian Kernel

In [17]:
import math
import scipy

def create_gau_filter(size, sigma):
  x_mean = (size + 1) / 2
  y_mean = (size + 1) / 2
  pi = math.pi
  sum = 0
  matrix = [[0] * size for i in range(size)]
  for i in range(size):
    for j in range(size):
      matrix[i][j] = (1 / sigma**2 / 2 / pi) * math.exp(-((i - x_mean)**2 + (j - y_mean)**2) / 2 / sigma**2) / sigma**2 / 2
      sum += matrix[i][j]
  matrix = torch.Tensor(matrix)
  matrix = matrix / sum
  # print(torch.sum(matrix))
  return matrix

def calculate(x, y, sigma):
  return ((x**2 + y**2 - 2 * sigma**2) / sigma**4) * math.exp(-1 * (x**2 + y **2) / 2 * sigma**2)

def create_log_filter(size, sigma):
  matrix = [[0] * size for i in range(size)]
  for i in range(size):
    for j in range(size):
      matrix[i][j] = calculate(i - size // 2, j - size // 2, sigma)
  matrix = torch.Tensor(matrix)
  # print(matrix)
  # print(torch.sum(matrix))
  return matrix

create_log_filter(3, 1)

create_gau_filter(5, 1)

tensor([[2.2167e-05, 2.7005e-04, 1.2103e-03, 1.9954e-03, 1.2103e-03],
        [2.7005e-04, 3.2899e-03, 1.4744e-02, 2.4309e-02, 1.4744e-02],
        [1.2103e-03, 1.4744e-02, 6.6079e-02, 1.0895e-01, 6.6079e-02],
        [1.9954e-03, 2.4309e-02, 1.0895e-01, 1.7962e-01, 1.0895e-01],
        [1.2103e-03, 1.4744e-02, 6.6079e-02, 1.0895e-01, 6.6079e-02]])

###Define the Model Inspired By Cone and Rod Cell

In [18]:
conv = nn.Conv2d(3, 3, 5, 1, padding=1)
conv.weight.shape
x = torch.randn(3, 4, 5)
print(x, 'x')
x = torch.flatten(x, 1)
print(x, 'x')

y = torch.randn(3, 4, 5)
print(y, 'y')
y = torch.flatten(x, 1)
print(y, 'y')

tensor([[[-0.3483,  0.1918,  0.3405,  0.9588, -0.3387],
         [-0.2818, -0.8398,  0.0582, -0.2635, -1.7742],
         [ 0.1453,  0.9793, -1.4312, -0.3346,  1.2268],
         [ 1.4233,  0.8699, -0.2412,  2.1485, -0.7564]],

        [[-2.0579,  1.7260, -2.3870,  2.2155,  0.1215],
         [-0.2720,  0.3469, -0.3147, -0.9637, -0.0929],
         [-1.2641,  1.0472, -1.5853, -0.1638, -1.5156],
         [-0.8209, -1.9388,  0.7356,  0.4327, -1.8492]],

        [[ 1.0815,  0.2984, -0.3323,  0.5167, -0.0198],
         [-1.4208,  1.1921, -1.0203,  1.4290, -0.2296],
         [-0.5082,  2.2443,  0.9400, -0.4019,  0.0142],
         [ 0.6599, -1.3512,  0.2596,  0.9150,  1.6426]]]) x
tensor([[-0.3483,  0.1918,  0.3405,  0.9588, -0.3387, -0.2818, -0.8398,  0.0582,
         -0.2635, -1.7742,  0.1453,  0.9793, -1.4312, -0.3346,  1.2268,  1.4233,
          0.8699, -0.2412,  2.1485, -0.7564],
        [-2.0579,  1.7260, -2.3870,  2.2155,  0.1215, -0.2720,  0.3469, -0.3147,
         -0.9637, -0.0929, -1.2

In [22]:
import scipy

size = 3
sigma = 2
gau_kernel = create_gau_filter(size, sigma)
log_kernel = create_log_filter(size, sigma)

class Net(nn.Module):
    def __init__(self, multiScale=True):
        super(Net, self).__init__()
        self.multiScale = multiScale
        self.conv1_rod = nn.Conv2d(1, 64, 5, 1)
        self.conv2_rod = nn.Conv2d(64, 64, 5, 1)
        self.conv1_cone = nn.Conv2d(3, 64, 5, 1)
        self.conv2_cone = nn.Conv2d(64, 64, 5, 1)
        self.conv1_log = nn.Conv2d(1, 64, 5, 1)
        self.conv2_log = nn.Conv2d(64, 64, 5, 1)
        self.fc1_rod = nn.Linear(64 * 5 * 5, 384)
        self.fc1_cone = nn.Linear(64 * 5 * 5, 384)
        self.fc1_log = nn.Linear(64 * 5 * 5, 384)
        self.fc2_rod = nn.Linear(384, 192)
        self.fc2_cone = nn.Linear(384, 192)
        self.fc2_log = nn.Linear(384, 192)
        self.gaussian = nn.Conv2d(3, 1, size, 1, padding=(size - 1) // 2)
        self.log = nn.Conv2d(3, 1, size, 1, padding=(size - 1) // 2)
        self.act = torch.nn.LeakyReLU()
        if (self.multiScale):
          self.fc = nn.Linear(192 * 3, 10)
        else:
          self.fc = nn.Linear(192, 10)

    def forward(self, x):
        # print(self.fc.weight.shape)
        with torch.no_grad():
          self.gaussian.weight[:][:] = gau_kernel
          self.log.weight[:][:] = log_kernel
        rod = self.gaussian(x)
        cone = torch.Tensor.clone(x)
        log = self.log(x)
        # cone part
        # print(cone.shape, 'cone init')
        cone = self.conv1_cone(cone)
        cone = F.relu(cone)
        cone = F.max_pool2d(cone, 2)
        # print(cone.shape, 'cone conv1')
        cone = self.conv2_cone(cone)
        cone = F.relu(cone)
        cone = F.max_pool2d(cone, 2)
        # print(cone.shape, 'cone conv2')
        cone = torch.flatten(cone, 1)
        cone = self.fc1_cone(cone)
        cone = F.relu(cone)
        cone = self.fc2_cone(cone)
        cone = F.relu(cone)
        # print(cone.shape, 'cone')
        if (self.multiScale):
          # rod part
          # print(rod.shape, 'rod init')
          rod = self.conv1_rod(rod)
          rod = F.relu(rod)
          rod = F.max_pool2d(rod, 2)
          # print(rod.shape, 'rod conv1')
          rod = self.conv2_rod(rod)
          rod = F.relu(rod)
          rod = F.max_pool2d(rod, 2)
          # print(rod.shape, 'rod conv2')
          rod = torch.flatten(rod, 1)
          rod = self.fc1_rod(rod)
          rod = F.relu(rod)
          rod = self.fc2_rod(rod)
          rod = F.relu(rod)

          log = self.conv1_log(log)
          log = self.act(log)
          log = F.max_pool2d(log, 2)
          # print(rod.shape, 'rod conv1')
          log = self.conv2_log(log)
          log = self.act(log)
          log = F.max_pool2d(log, 2)
          # print(rod.shape, 'rod conv2')
          log = torch.flatten(log, 1)
          log = self.fc1_log(log)
          log = self.act(log)
          log = self.fc2_log(log)
          log = self.act(log)
          # print(rod.shape, 'rod')
          # concatenate part
          output = torch.cat([cone, rod, log], dim=1)
          # output = torch.max(cone, rod)
          # print(self.fc.weight.shape, output.shape)
        output = self.fc(output)
        # output = F.relu(output)
        # output = self.fc1(output)
        # output = F.relu(output)
        # output = self.fc2(output)
        output = F.log_softmax(output, dim=1)

        return output



net = Net(multiScale=True)
# train(net)

In [20]:
for name, named_parameter in net.named_parameters():
  if named_parameter.requires_grad:
    print(name, named_parameter.shape)

conv1_rod.weight torch.Size([64, 1, 5, 5])
conv1_rod.bias torch.Size([64])
conv2_rod.weight torch.Size([64, 64, 5, 5])
conv2_rod.bias torch.Size([64])
conv1_cone.weight torch.Size([64, 3, 5, 5])
conv1_cone.bias torch.Size([64])
conv2_cone.weight torch.Size([64, 64, 5, 5])
conv2_cone.bias torch.Size([64])
conv1_log.weight torch.Size([64, 1, 5, 5])
conv1_log.bias torch.Size([64])
conv2_log.weight torch.Size([64, 64, 5, 5])
conv2_log.bias torch.Size([64])
fc1_rod.weight torch.Size([384, 1600])
fc1_rod.bias torch.Size([384])
fc1_cone.weight torch.Size([384, 1600])
fc1_cone.bias torch.Size([384])
fc1_log.weight torch.Size([384, 1600])
fc1_log.bias torch.Size([384])
fc2_rod.weight torch.Size([192, 384])
fc2_rod.bias torch.Size([192])
fc2_cone.weight torch.Size([192, 384])
fc2_cone.bias torch.Size([192])
fc2_log.weight torch.Size([192, 384])
fc2_log.bias torch.Size([192])
gaussian.weight torch.Size([1, 3, 3, 3])
gaussian.bias torch.Size([1])
log.weight torch.Size([1, 3, 3, 3])
log.bias torch.

In [23]:
train(net, lr=0.0003, num_epochs=5)
train(net, lr=0.001, num_epochs=10)
train(net, lr=0.0003, num_epochs=5)

Epoch [1/5], Step [100/781], Loss: 1.8572
Epoch [1/5], Step [200/781], Loss: 1.8065
Epoch [1/5], Step [300/781], Loss: 1.6371
Epoch [1/5], Step [400/781], Loss: 1.5467
Epoch [1/5], Step [500/781], Loss: 1.5102
Epoch [1/5], Step [600/781], Loss: 1.5464
Epoch [1/5], Step [700/781], Loss: 1.5326
Epoch [2/5], Step [100/781], Loss: 1.2468
Epoch [2/5], Step [200/781], Loss: 1.0352
Epoch [2/5], Step [300/781], Loss: 1.2656
Epoch [2/5], Step [400/781], Loss: 1.1956
Epoch [2/5], Step [500/781], Loss: 1.1424
Epoch [2/5], Step [600/781], Loss: 1.3130
Epoch [2/5], Step [700/781], Loss: 1.1043
Epoch [3/5], Step [100/781], Loss: 1.0576
Epoch [3/5], Step [200/781], Loss: 1.1049
Epoch [3/5], Step [300/781], Loss: 1.2397
Epoch [3/5], Step [400/781], Loss: 0.7911
Epoch [3/5], Step [500/781], Loss: 1.2212
Epoch [3/5], Step [600/781], Loss: 0.9235
Epoch [3/5], Step [700/781], Loss: 0.7786
Epoch [4/5], Step [100/781], Loss: 0.8869
Epoch [4/5], Step [200/781], Loss: 0.9801
Epoch [4/5], Step [300/781], Loss:

KeyboardInterrupt: ignored

###32 -> 32

In [24]:
test_transform = transforms.Compose([
    transforms.CenterCrop(32),
    transforms.Resize(32, interpolation=2),
    transforms.ToTensor(),
])

test_data = dsets.CIFAR10(root = './data', train = False,
                       transform = test_transform, download=True)

test_gen = torch.utils.data.DataLoader(dataset = test_data,
                                      batch_size = batch_size, 
                                      shuffle = False)
print("Architecture:\n", "\nnet.Multiscale=",net.multiScale)
test(net)

Files already downloaded and verified
Architecture:
 
net.Multiscale= True
RESULTS OF MULTISCALE CNN
Train accuracy of the model: 99.998 %
tensor(50000, device='cuda:0') 50000
Test accuracy of the model: 72.703 %
tensor(7271, device='cuda:0') 10000


###32 -> 28

In [25]:
test_transform = transforms.Compose([
    transforms.CenterCrop(28),
    transforms.Resize(32, interpolation=2),
    transforms.ToTensor(),
])

test_data = dsets.CIFAR10(root = './data', train = False,
                       transform = test_transform, download=True)

test_gen = torch.utils.data.DataLoader(dataset = test_data,
                                      batch_size = batch_size, 
                                      shuffle = False)
print("Architecture:\n", "\nnet.Multiscale=",net.multiScale)
test(net)

Files already downloaded and verified
Architecture:
 
net.Multiscale= True
RESULTS OF MULTISCALE CNN
Train accuracy of the model: 99.998 %
tensor(50000, device='cuda:0') 50000
Test accuracy of the model: 66.853 %
tensor(6686, device='cuda:0') 10000


###32 -> 24

In [None]:
test_transform = transforms.Compose([
    transforms.CenterCrop(24),
    transforms.Resize(32, interpolation=2),
    transforms.ToTensor(),
])

test_data = dsets.CIFAR10(root = './data', train = False,
                       transform = test_transform, download=True)

test_gen = torch.utils.data.DataLoader(dataset = test_data,
                                      batch_size = batch_size, 
                                      shuffle = False)
print("Architecture:\n", "\nnet.Multiscale=",net.multiScale)
test(net)