In [None]:
# Colab-only setup: mount Google Drive so the project files stored there are reachable.
from google.colab import drive
drive.mount('/content/drive')

# import-ipynb is presumably needed so the next cell can `import DataLoaders`
# (another notebook) as a module. NOTE(review): unpinned install — consider pinning a version.
!pip install import-ipynb
# Work from the project's notebooks directory; presumably this is where DataLoaders lives.
%cd /content/drive/My\ Drive/Colab\ Notebooks/RetinaSmartCamera/notebooks

In [None]:
import import_ipynb
import DataLoaders
from DataLoaders import get_dataloader

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data.sampler import SubsetRandomSampler

from torch.utils.data.dataset import Dataset, Subset
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from torchvision.utils import make_grid

from torch.utils.tensorboard import SummaryWriter

import re
import os

Meng Dong's PCNet, used here for image classification.

In [None]:
class Net(nn.Module):
    """PCNet image classifier (Meng Dong's architecture).

    A stack of predictive-coding conv layers (``PcConvBp``), each preceded by
    BatchNorm and optionally followed by 2x2 max-pooling. A
    BatchNorm + ReLU + global-average-pool + linear head follows the stack.

    Args:
        num_classes: Number of output logits.
        num_layers: How many of the 8 predefined PC layers to use (taken from
            the front of the schedule). Defaults to 3, the previous fixed value.
        circles: Feedback/feedforward refinement iterations per PC layer.
            Defaults to 5, the previous fixed value.

    Raises:
        ValueError: if ``num_layers`` is not in ``[1, 8]``.
    """

    class PcConvBp(nn.Module):
        """One predictive-coding convolution layer.

        The layer first computes ``y = relu(FFconv(x))``. It then runs ``circles``
        times ``y = FFconv(relu(x - FBconv(y))) * relu(b0 + 1) + y``: the
        rectified prediction error is fed forward again and scaled by a
        learned per-channel gain. Finally a 1x1 ``bypass`` conv of ``x`` is added.

        Note: ``FBconv`` has no ``output_padding`` and ``bypass`` is always
        stride 1, so shapes only line up for ``stride=1`` (the only value
        ``Net`` uses).
        """

        def __init__(self, inchan, outchan, kernel_size=3, stride=1, padding=1, circles=0, bias=False):
            super().__init__()
            self.FFconv = nn.Conv2d(inchan, outchan, kernel_size, stride, padding, bias=bias)
            self.FBconv = nn.ConvTranspose2d(outchan, inchan, kernel_size, stride, padding, bias=bias)
            # Per-channel gain on the error correction; stored as relu(b0 + 1), so it starts at 1.
            self.b0 = nn.ParameterList([nn.Parameter(torch.zeros(1, outchan, 1, 1))])
            self.relu = nn.ReLU(inplace=True)
            self.circles = circles
            self.bypass = nn.Conv2d(inchan, outchan, kernel_size=1, stride=1, bias=False)

        def forward(self, x):
            y = self.relu(self.FFconv(x))
            b0 = F.relu(self.b0[0] + 1.0).expand_as(y)
            for _ in range(self.circles):
                # Feed back the current prediction, feed forward the rectified error.
                y = self.FFconv(self.relu(x - self.FBconv(y))) * b0 + y
            y = y + self.bypass(x)
            return y

    def __init__(self, num_classes, num_layers=3, circles=5):
        super().__init__()
        all_ics = [3, 64, 64, 128, 128, 256, 256, 512]  # input channels
        all_ocs = [64, 64, 128, 128, 256, 256, 512, 512]  # output channels
        all_maxpool = [False, False, True, False, True, False, False, False]  # downsample flags
        if not 1 <= num_layers <= len(all_ics):
            raise ValueError(f"num_layers must be in [1, {len(all_ics)}], got {num_layers}")

        self.ics = all_ics[:num_layers]
        self.ocs = all_ocs[:num_layers]
        self.maxpool = all_maxpool[:num_layers]

        self.circles = circles
        self.nlays = len(self.ics)

        # construct PC layers
        self.PcConvs = nn.ModuleList([self.PcConvBp(self.ics[i], self.ocs[i], circles=self.circles) for i in range(self.nlays)])
        self.BNs = nn.ModuleList([nn.BatchNorm2d(self.ics[i]) for i in range(self.nlays)])
        # Linear layer
        self.linear = nn.Linear(self.ocs[-1], num_classes)
        self.maxpool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU(inplace=True)
        self.BNend = nn.BatchNorm2d(self.ocs[-1])

    def forward(self, x):
        for i in range(self.nlays):
            x = self.BNs[i](x)
            x = self.PcConvs[i](x)  # predictive-coding layer (feedforward + feedback loops + bypass)
            if self.maxpool[i]:
                x = self.maxpool2d(x)

        # Classifier: global average pooling over the whole spatial extent.
        # adaptive_avg_pool2d(., 1) also handles non-square maps, whereas a
        # square kernel of size x.size(-1) failed (H < W) or mis-sized (H > W).
        out = F.adaptive_avg_pool2d(self.relu(self.BNend(x)), 1)
        out = torch.flatten(out, 1)
        return self.linear(out)
# Build the classifier for 20 classes and show its layer structure.
NUM_CLASSES = 20
model = Net(num_classes=NUM_CLASSES)
print(model)

In [None]:
# My implementation of a compressing (encoder-decoder) PCNet
class PcConvBlock(nn.Module):
    """Predictive-coding conv block that either downsamples (encode) or upsamples (decode).

    If ``in_channel < out_channel`` the block is an encoder. Its feedforward
    path is a strided ``Conv2d`` and its feedback path a ``ConvTranspose2d``.
    Otherwise, including ``in_channel == out_channel``, the block is a decoder
    and the roles are swapped, so the feedforward path upsamples by ``stride``.

    Forward pass:
    1. ``y = relu(ff(x))``.
    2. Repeat ``loops`` times: ``y = ff(relu(x - fb(y))) * relu(alpha + 1) + y``.
    3. Add a 1x1 bypass of ``x`` and apply BatchNorm.

    Args:
        in_channel: Channels of the input.
        out_channel: Channels of the output.
        loops: Number of feedback refinement iterations.
        kernel_size: Spatial kernel size, assumed odd. Padding is
            ``kernel_size // 2`` so the feedforward, feedback and bypass paths
            produce matching shapes.
        stride: Down/up-sampling factor. For an encoder, input H and W should
            be divisible by ``stride`` so ``fb(y)`` matches ``x`` in size.
    """

    def __init__(self, in_channel, out_channel, loops=0, kernel_size=3, stride=1):
        super(PcConvBlock, self).__init__()
        encode_block = in_channel < out_channel
        # "Same"-style padding for odd kernels. A fixed padding of 1 only kept
        # the main path aligned with the 1x1 bypass when kernel_size == 3.
        padding = kernel_size // 2

        if encode_block:
            self.feedforward = nn.Conv2d(in_channel, out_channel, kernel_size, stride=stride, padding=padding, bias=False)
            # output_padding=stride-1 makes the transposed conv exactly invert the strided conv's size.
            self.feedbackward = nn.ConvTranspose2d(out_channel, in_channel, kernel_size, stride=stride, padding=padding, output_padding=stride-1, bias=False)
            self.bypass = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, padding=0, bias=False)
        else:
            self.feedforward = nn.ConvTranspose2d(in_channel, out_channel, kernel_size, stride=stride, output_padding=stride-1, padding=padding, bias=False)
            self.feedbackward = nn.Conv2d(out_channel, in_channel, kernel_size, stride=stride, padding=padding, bias=False)
            self.bypass = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=1, stride=stride, output_padding=stride-1, padding=0, bias=False)

        # Per-channel gain on the error correction, applied as relu(alpha + 1), so it starts at 1.
        self.alpha = nn.Parameter(torch.zeros((1, out_channel, 1, 1)))
        self.relu = nn.ReLU(True)
        self.loops = loops
        self.batchnorm = nn.BatchNorm2d(out_channel)

    def forward(self, x):
        y = self.relu(self.feedforward(x))
        alpha = self.relu(self.alpha[0] + 1.0).expand_as(y)
        for _ in range(self.loops):
            # Feed back the current prediction, feed forward the rectified error.
            y = self.feedforward(self.relu(x - self.feedbackward(y))) * alpha + y
        y = y + self.bypass(x)
        return self.batchnorm(y)

# Quick wiring check of a single encoder block. A small input is enough to
# verify shapes; the stride-2 encoder should halve H and W.
block = PcConvBlock(3, 32, loops=3, kernel_size=3, stride=2)
with torch.no_grad():
    block_out = block(torch.ones((2, 3, 64, 64)))
assert block_out.shape == (2, 32, 32, 32), block_out.shape

class PCNet(nn.Module):
    """Compressing predictive-coding autoencoder built from ``PcConvBlock`` stages.

    Pipeline:
    - Four encoder blocks (3 -> 12 -> 32 -> 64 -> 128 channels, each stride 2).
    - A strided 3x3 conv bottleneck and its transposed-conv counterpart.
    - Four decoder blocks mirroring the encoder back to 3 channels.

    Every stage has stride 2, so H and W should be multiples of 32 for the
    output to have the same spatial size as the input.

    Args:
        loops: Feedback iterations used by every ``PcConvBlock``.
    """

    def __init__(self, loops=3):
        super(PCNet, self).__init__()
        widths = (3, 12, 32, 64, 128)

        # Encoder stages encode1..encode4 (registered in this order).
        for level in range(1, 5):
            setattr(self, f"encode{level}",
                    PcConvBlock(widths[level - 1], widths[level], loops=loops, stride=2))

        self.encode_bottleneck = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
        )
        self.decode_bottleneck = nn.Sequential(
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
            nn.BatchNorm2d(128),
        )

        # Decoder stages decode4..decode1, mirroring the encoder widths.
        for level in range(4, 0, -1):
            setattr(self, f"decode{level}",
                    PcConvBlock(widths[level], widths[level - 1], loops=loops, stride=2))

    def forward(self, x):
        pipeline = (
            self.encode1, self.encode2, self.encode3, self.encode4,
            self.encode_bottleneck, self.decode_bottleneck,
            self.decode4, self.decode3, self.decode2, self.decode1,
        )
        for stage in pipeline:
            x = stage(x)
        return x
# Smoke test: the autoencoder should reconstruct an output with the input's shape.
test = PCNet(loops=3)
x = torch.ones((8, 3, 512, 512))
with torch.no_grad():
    out = test(x)
print(out.shape)
assert out.shape == x.shape, f"expected {tuple(x.shape)}, got {tuple(out.shape)}"