# OctaveConv

- https://github.com/ThoroughImages/OctConv/blob/master/demo.ipynb

In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils.data import DataLoader

from octconv import OctConv2d, OctReLU, OctMaxPool2d

## 1. Loading FashionMNIST Dataset

In [2]:
batch_size = 512

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

train_dataset = datasets.FashionMNIST(root='./data',
                                      train=True,
                                      transform=transform,
                                      download=True)

test_dataset = datasets.FashionMNIST(root='./data',
                                     train=False,
                                     transform=transform)

In [3]:
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size, 
                         shuffle=False)

In [4]:
# MNIST Dataset
len(train_loader)

118

## 2. Building Model

In [5]:
class OctCNN(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.convs = nn.Sequential(OctConv2d('first', in_channels=1, out_channels=32, kernel_size=3),
                                   OctReLU(),
                                   OctConv2d('regular', in_channels=32, out_channels=64, kernel_size=3),
                                   OctReLU(),
                                   OctConv2d('regular', in_channels=64, out_channels=128, kernel_size=3),
                                   OctReLU(),
                                   OctMaxPool2d(2),
                                   OctConv2d('regular', in_channels=128, out_channels=128, kernel_size=3),
                                   OctReLU(),
                                   OctConv2d('last', in_channels=128, out_channels=128, kernel_size=3),
                                   nn.ReLU(),
                                   nn.MaxPool2d(2),
                                  )
        self.fc = nn.Sequential(nn.Linear(6272, 256),
                                nn.Dropout(0.5),
                                nn.Linear(256, 10)
                                )
    
    
    def forward(self, x):  
        x = self.convs(x)
        x = x.view(-1, x.size(1) * x.size(2) * x.size(3))
        x = self.fc(x)
        return x

In [6]:
model = OctCNN()
if torch.cuda.is_available():
    model = model.cuda()
    model = nn.DataParallel(model)
    
criterion = nn.CrossEntropyLoss()

In [None]:
num_epochs = 30
learning_rate = 0.0001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

num_iter = 0

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):

        if torch.cuda.is_available():
            images, labels = images.cuda(), labels.cuda()
        
        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()
        
        # Forward pass to get output/logits
        outputs = model(images)
        
        # Calculate Loss: softmax --> cross entropy loss
        loss = criterion(outputs, labels)
        
        # Getting gradients w.r.t. parameters
        loss.backward()
        
        # Updating parameters
        optimizer.step()
        
        num_iter += 1
        
        if num_iter % 20 == 0:
            # Calculate Accuracy         
            correct = 0
            total = 0
            # Iterate through test dataset
            for images, labels in test_loader:
                if torch.cuda.is_available():
                    images, labels = images.cuda(), labels.cuda()
                    
                images = Variable(images)
                
                # Forward pass only to get logits/output
                outputs = model(images)
                
                # Get predictions from the maximum value
                _, predicted = torch.max(outputs.data, 1)
                
                # Total number of labels
                total += labels.size(0)
                correct += (predicted == labels).sum().detach().cpu()
            
            accuracy = 100 * (correct.item() / total)
            
            # Print Loss
            print('Iteration: {} Loss: {}. Accuracy: {:4f}'.format(num_iter, loss.item(), accuracy))

Iteration: 20 Loss: 0.4314352571964264. Accuracy: 84.040000
Iteration: 40 Loss: 0.4153634309768677. Accuracy: 84.680000
Iteration: 60 Loss: 0.4263860881328583. Accuracy: 83.880000
Iteration: 80 Loss: 0.46227744221687317. Accuracy: 84.760000
Iteration: 100 Loss: 0.5090758204460144. Accuracy: 84.670000
Iteration: 120 Loss: 0.38190385699272156. Accuracy: 84.940000
Iteration: 140 Loss: 0.3928123116493225. Accuracy: 85.410000
Iteration: 160 Loss: 0.36810436844825745. Accuracy: 85.840000
Iteration: 180 Loss: 0.393439918756485. Accuracy: 85.660000
Iteration: 200 Loss: 0.3972342610359192. Accuracy: 86.050000
Iteration: 220 Loss: 0.37476593255996704. Accuracy: 85.910000
Iteration: 240 Loss: 0.3133814334869385. Accuracy: 85.850000
Iteration: 260 Loss: 0.3304150700569153. Accuracy: 86.060000
Iteration: 280 Loss: 0.3528797924518585. Accuracy: 86.230000
Iteration: 300 Loss: 0.30910128355026245. Accuracy: 86.170000
Iteration: 320 Loss: 0.34352603554725647. Accuracy: 86.510000
Iteration: 340 Loss: 0.