A slightly modified version of Meng Dong's PCN (predictive coding network), adapted because the original version produced warnings, and rewritten for clarity.

In [None]:
#from google.colab import drive
#drive.mount('/content/drive')

#%cd /content/drive/My\ Drive/Colab\ Notebooks/RetinaSmartCamera/notebooks

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class PCN(nn.Module):
    """Predictive coding network (PCN) image classifier.

    Architecture: three predictive-coding conv blocks (each preceded by
    BatchNorm, the last followed by 2x2 max-pooling), a final BatchNorm + ReLU,
    global average pooling, and a linear classifier.

    Args:
        num_classes: number of output logits.
        circles: number of feedforward/feedback refinement cycles run inside
            every ``PcConv`` layer.
        in_channels: number of channels of the input images (default 3).
    """

    class PcConv(nn.Module):
        """Convolution with recurrent predictive-coding refinement.

        The representation ``y`` starts as ``relu(FFconv(x))``. On each cycle
        the feedback (transposed) conv predicts the input from ``y``. The
        rectified prediction error ``relu(x - FBconv(y))`` is passed through
        the feedforward conv again and added to ``y``, scaled by a learned
        per-channel rate. Finally a 1x1 shortcut conv of the input is added.

        Spatial size is preserved (stride 1, "same" padding), which requires
        an odd ``kernel_size``.

        Raises:
            ValueError: if ``kernel_size`` is even.
        """

        def __init__(self, inchan, outchan, kernel_size=3, circles=0):
            super().__init__()
            if kernel_size % 2 != 1:
                raise ValueError(f"kernel_size must be odd, got {kernel_size}")
            # "Same" padding for stride 1. A hard-coded padding of 1 only fits
            # kernel_size=3; otherwise FFconv and bypass outputs differ in size.
            padding = kernel_size // 2
            # Number of feedforward/feedback refinement cycles.
            self.circles = circles
            self.relu = nn.ReLU(inplace=True)
            # Feedforward conv: input -> representation.
            self.FFconv = nn.Conv2d(inchan, outchan, kernel_size, stride=1,
                                    padding=padding, bias=False)
            # Feedback transposed conv: representation -> prediction of input.
            self.FBconv = nn.ConvTranspose2d(outchan, inchan, kernel_size, stride=1,
                                             padding=padding, bias=False)
            # 1x1 shortcut conv added to the output (not recurrent).
            self.bypass = nn.Conv2d(inchan, outchan, kernel_size=1, stride=1, bias=False)
            # Per-channel update rate for the error correction. The effective
            # rate is relu(alpha + 1), so it starts at 1 (alpha is zero-init).
            self.alpha = nn.Parameter(torch.zeros(1, outchan, 1, 1))

        def forward(self, x):
            """Run the feedforward pass plus ``circles`` correction cycles."""
            y = self.relu(self.FFconv(x))
            # Shape (1, C, 1, 1); broadcasts over (N, C, H, W).
            rate = F.relu(self.alpha + 1.0)
            for _ in range(self.circles):
                # Rectified error between the input and its top-down prediction.
                error = self.relu(x - self.FBconv(y))
                y = self.FFconv(error) * rate + y
            return y + self.bypass(x)

    def __init__(self, num_classes, circles=4, in_channels=3):
        super().__init__()
        # Predictive coding blocks.
        self.PC_block1 = nn.Sequential(nn.BatchNorm2d(in_channels),
                                       self.PcConv(in_channels, 64, circles=circles))
        self.PC_block2 = nn.Sequential(nn.BatchNorm2d(64),
                                       self.PcConv(64, 64, circles=circles))
        self.PC_block3 = nn.Sequential(nn.BatchNorm2d(64),
                                       self.PcConv(64, 128, circles=circles),
                                       nn.MaxPool2d(kernel_size=2, stride=2))
        # Final BatchNorm and ReLU.
        self.final_bn = nn.Sequential(nn.BatchNorm2d(128),
                                      nn.ReLU(True))
        # Fully connected classifier on the pooled 128-d features.
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return class logits of shape (N, num_classes) for a batch ``x``."""
        x = self.PC_block1(x)
        x = self.PC_block2(x)
        x = self.PC_block3(x)
        x = self.final_bn(x)

        # Global average pooling to (N, 128, 1, 1). The adaptive form also
        # handles non-square feature maps. avg_pool2d(x, x.size(-1)) used
        # the width as a square kernel and broke when H != W.
        out = F.adaptive_avg_pool2d(x, 1)
        out = torch.flatten(out, 1)
        # Return predictions (logits).
        return self.fc(out)