In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable
import torch.nn.functional as F

from dstorch.utils import random_weight_init
from octconv import OctConv2d, OctReLU, OctMaxPool2d

# 1. Loading FashionMNIST Dataset

In [2]:
batch_size = 512

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

train_dataset = datasets.FashionMNIST(root='./data', 
                                      train=True,
                                      transform=transform,
                                      download=True
                                      )

test_dataset = datasets.FashionMNIST(root='./data', 
                                     train=False, 
                                     transform=transform
                                     )

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size,            
                                           shuffle=True
                                           )

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False
                                          )

# 2. Building Model

In [3]:
class OctCNN(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.convs = nn.Sequential(OctConv2d('first', in_channels=1, out_channels=32, kernel_size=3),
                                   OctReLU(),
                                   OctConv2d('regular', in_channels=32, out_channels=64, kernel_size=3),
                                   OctReLU(),
                                   OctConv2d('regular', in_channels=64, out_channels=128, kernel_size=3),
                                   OctReLU(),
                                   OctMaxPool2d(2),
                                   OctConv2d('regular', in_channels=128, out_channels=128, kernel_size=3),
                                   OctReLU(),
                                   OctConv2d('last', in_channels=128, out_channels=128, kernel_size=3),
                                   nn.ReLU(),
                                   nn.MaxPool2d(2),
                                  )
        self.fc = nn.Sequential(nn.Linear(6272, 256),
                                nn.Dropout(0.5),
                                nn.Linear(256, 10)
                                )
    
    
    def forward(self, x):
        
        x = self.convs(x)
        x = x.view(-1, x.size(1) * x.size(2) * x.size(3))
        x = self.fc(x)
        
        return x

# 3. Instantiate Model and Loss Criterion

In [4]:
model = OctCNN()

if torch.cuda.is_available():
    model = model.cuda()
    model = nn.DataParallel(model)
    
criterion = nn.CrossEntropyLoss()

# 4. Start Training 

In [5]:
num_epochs = 30
learning_rate = 0.0001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

num_iter = 0

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):

        if torch.cuda.is_available():
            images, labels = images.cuda(), labels.cuda()
            
        images = Variable(images)
        labels = Variable(labels)
        
        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()
        
        # Forward pass to get output/logits
        outputs = model(images)
        
        # Calculate Loss: softmax --> cross entropy loss
        loss = criterion(outputs, labels)
        
        # Getting gradients w.r.t. parameters
        loss.backward()
        
        # Updating parameters
        optimizer.step()
        
        num_iter += 1
        
        if num_iter % 500 == 0:
            # Calculate Accuracy         
            correct = 0
            total = 0
            # Iterate through test dataset
            for images, labels in test_loader:
                if torch.cuda.is_available():
                    images, labels = images.cuda(), labels.cuda()
                    
                images = Variable(images)
                
                # Forward pass only to get logits/output
                outputs = model(images)
                
                # Get predictions from the maximum value
                _, predicted = torch.max(outputs.data, 1)
                
                # Total number of labels
                total += labels.size(0)
                correct += (predicted == labels).sum().detach().cpu()
            
            accuracy = 100 * (correct.item() / total)
            
            # Print Loss
            print('Iteration: {} Loss: {}. Accuracy: {:4f}'.format(num_iter, loss.item(), accuracy))

Iteration: 500 Loss: 0.362911581993103. Accuracy: 84.900000
Iteration: 1000 Loss: 0.31636983156204224. Accuracy: 87.040000
Iteration: 1500 Loss: 0.30958276987075806. Accuracy: 89.330000
Iteration: 2000 Loss: 0.21788382530212402. Accuracy: 89.420000
Iteration: 2500 Loss: 0.18522046506404877. Accuracy: 90.270000
Iteration: 3000 Loss: 0.1775389015674591. Accuracy: 90.890000
Iteration: 3500 Loss: 0.15970204770565033. Accuracy: 91.020000
