# Octave Convolution
My try with pyTorch with a case study of Octave Convolution from https://arxiv.org/pdf/1904.05049.pdf

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Dataset Generation

In [2]:
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.optim as optim
from tqdm import tqdm
import numpy as np

transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])

trainset = torchvision.datasets.CIFAR10(root='/home/ubuntu/research/data', train=True,
                            download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset,batch_size=512,shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='/home/ubuntu/research/data', train=False,
                            download=True, transform=transform)

testloader = torch.utils.data.DataLoader(testset,batch_size=1024,shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def imshow(img):
    img = img / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.show()

dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

Files already downloaded and verified
Files already downloaded and verified


<Figure size 640x480 with 1 Axes>

  cat  frog   cat plane


In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3,6,5,padding=2)
        self.pool = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6,16,5,padding=2)
        self.fc1 = nn.Linear(16*8*8,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1,16*8*8)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x 

Lets create a class for our Octave convolution. We shall create it so that it can easily replace vanilla convolution. alpha is the portion of the channels that are low frequency.

For all convolution layers (except first and last) alpha_in and alpha_out are 0.5
For first convolution layer: alpha_in = 0 and alpha_out = 0.5
For last convolution layer: alpha_in = 0.5 and alpha_out = 0.0
We want to pack the output (hf and lf components) such that they hf and lf all of the same size. during forward() we shall unpack them to their respective sizes. This is so that we can use OctConv as is with other pyTorch modules like Relu, pool etc.

In [4]:
class OctConv(nn.Module): 
    def __init__(self, ch_in, ch_out, kernel_size, stride=1, alphas=[0.5,0.5], padding=0): 
        super(OctConv, self).__init__()

        # get layer parameters 
        self.alpha_in, self.alpha_out = alphas
        assert 0 <= self.alpha_in <= 1 and 0 <= self.alpha_in <= 1, "Alphas must be in interval [0, 1]"
        
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = (kernel_size - stride ) // 2 ## padding
        
        # Calculate the exact number of high/low frequency channels 
        self.ch_in_lf = int(self.alpha_in*ch_in)
        self.ch_in_hf = ch_in - self.ch_in_lf
        self.ch_out_lf = int(self.alpha_out*ch_out) 
        self.ch_out_hf = ch_out - self.ch_out_lf

        # Create convolutional and other modules necessary
        self.hasLtoL = self.hasLtoH = self.hasHtoL = self.hasHtoH = False
        if (self.ch_in_lf and self.ch_out_lf):    
            self.hasLtoL = True
            self.conv_LtoL = nn.Conv2d(self.ch_in_lf, self.ch_out_lf, self.kernel_size, padding=self.padding)
        if (self.ch_in_lf and self.ch_out_hf): 
            self.hasLtoH = True
            self.conv_LtoH = nn.Conv2d(self.ch_in_lf, self.ch_out_hf, self.kernel_size, padding=self.padding)
        if (self.ch_in_hf and self.ch_out_lf):
            self.hasHtoL = True
            self.conv_HtoL = nn.Conv2d(self.ch_in_hf, self.ch_out_lf, self.kernel_size, padding=self.padding)
        if (self.ch_in_hf and self.ch_out_hf):
            self.hasHtoH = True
            self.conv_HtoH = nn.Conv2d(self.ch_in_hf, self.ch_out_hf, self.kernel_size, padding=self.padding)
        self.avg_pool  = nn.AvgPool2d(2,2)
        
    def forward(self, input): 
        
        # Split input into high frequency and low frequency components
        fmap_w = input.shape[-1]
        fmap_h = input.shape[-2]
        # We resize the high freqency components to the same size as the low frequency component when 
        # sending out as output. So when bringing in as input, we want to reshape it to have the original  
        # size as the intended high frequnecy channel (if any high frequency component is available). 
        input_hf = input
        if (self.ch_in_lf):
            input_hf = input[:,:self.ch_in_hf*4,:,:].reshape(-1,self.ch_in_hf,fmap_h*2,fmap_w*2)
            input_lf = input[:,self.ch_in_hf*4:,:,:]    
        
        # Create all conditional branches 
        LtoH = HtoH = LtoL = HtoL = 0.
        if (self.hasLtoL):
            LtoL = self.conv_LtoL(input_lf)
        if (self.hasHtoH):
            HtoH = self.conv_HtoH(input_hf)
            op_h, op_w = HtoH.shape[-2]//2, HtoH.shape[-1]//2
            HtoH = HtoH.reshape(-1, self.ch_out_hf*4, op_h, op_w)
        if (self.hasLtoH):
            LtoH = F.interpolate(self.conv_LtoH(input_lf), scale_factor=2, mode='bilinear')
            op_h, op_w = LtoH.shape[-2]//2, LtoH.shape[-1]//2
            LtoH = LtoH.reshape(-1, self.ch_out_hf*4, op_h, op_w)
        if (self.hasHtoL):
            HtoL = self.avg_pool(self.conv_HtoL(input_hf))
        
        # Elementwise addition of high and low freq branches to get the output
        out_hf = LtoH + HtoH
        out_lf = LtoL + HtoL
        
        if (self.ch_out_lf == 0):
            return out_hf
        if (self.ch_out_hf == 0):
            return out_lf
        op = torch.cat([out_hf,out_lf],dim=1)
        return op

Lets create our network using our new convolution

In [5]:
class Net_OctConv(nn.Module):
    def __init__(self):
        super(Net_OctConv, self).__init__()
        self.conv1 = OctConv(3,6,5,alphas=[0.,0.5])
        self.pool = nn.MaxPool2d(2,2)
        self.conv2 = OctConv(6,16,5,alphas=[0.5,0.])
        self.fc1 = nn.Linear(16*8*8,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)

    def forward(self, input):
        x = self.pool(F.relu(self.conv1(input)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1,16*8*8)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        #print("\tIn Model: input size", input.size(),
        #      "output size", x.size())
        return x 

In [6]:
net = Net_OctConv()
#net = Net()
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    net = nn.DataParallel(net)
net.to(device)

Let's use 4 GPUs!


DataParallel(
  (module): Net_OctConv(
    (conv1): OctConv(
      (conv_HtoL): Conv2d(3, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (conv_HtoH): Conv2d(3, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (avg_pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): OctConv(
      (conv_LtoH): Conv2d(3, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (conv_HtoH): Conv2d(3, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (avg_pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    )
    (fc1): Linear(in_features=1024, out_features=120, bias=True)
    (fc2): Linear(in_features=120, out_features=84, bias=True)
    (fc3): Linear(in_features=84, out_features=10, bias=True)
  )
)

In [7]:
from torchsummary import summary
summary(net, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 3, 32, 32]             228
            Conv2d-2            [-1, 3, 32, 32]             228
            Conv2d-3            [-1, 3, 32, 32]             228
            Conv2d-4            [-1, 3, 32, 32]             228
         AvgPool2d-5            [-1, 3, 16, 16]               0
         AvgPool2d-6            [-1, 3, 16, 16]               0
           OctConv-7           [-1, 15, 16, 16]               0
           OctConv-8           [-1, 15, 16, 16]               0
         MaxPool2d-9             [-1, 15, 8, 8]               0
        MaxPool2d-10             [-1, 15, 8, 8]               0
           Conv2d-11           [-1, 16, 16, 16]           1,216
           Conv2d-12           [-1, 16, 16, 16]           1,216
           Conv2d-13             [-1, 16, 8, 8]           1,216
           Conv2d-14             [-1, 1

  "See the documentation of nn.Upsample for details.".format(mode))


### Train and Validate 

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    
for epoch in tqdm(range(20)):
    running_loss = 0.0
    for i, data in enumerate(trainloader,0):
        inputs_cpu, labels_cpu = data
        inputs, labels = inputs_cpu.to(device), labels_cpu.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        #print("Outside: input size", inputs.size(),
        #  "output_size", outputs.size())
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' % (epoch +1, i+1, running_loss/2000))
            running_loss = 0

print('Finished Training')

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        inputs_cpu, labels_cpu = data
        inputs, labels = inputs_cpu.to(device), labels_cpu.to(device)
        outputs = net(inputs)
        _,predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print ('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))	

  "See the documentation of nn.Upsample for details.".format(mode))
 55%|█████▌    | 11/20 [01:16<00:58,  6.53s/it]