In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch
from google.colab import drive
drive.mount('/content/drive')
from glob import glob
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchsummary import summary
from fastai.metrics import accuracy
from fastai.vision import *
import numpy as np
!pip install fft-conv-pytorch
from fft_conv_pytorch import fft_conv, FFTConv2d

Mounted at /content/drive
Collecting fft-conv-pytorch
  Downloading fft_conv_pytorch-1.1.3-py3-none-any.whl (6.6 kB)
Installing collected packages: fft-conv-pytorch
Successfully installed fft-conv-pytorch-1.1.3


In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Dataset root on the mounted Google Drive.
path = "/content/drive/MyDrive/Dataset/grain_dataset"
SEED = 41
np.random.seed(SEED)
# NOTE: in fastai v1, `ImageDataBunch.from_folder` only uses the `train`/`valid`
# folder names when `valid_pct` is None. With `valid_pct=0.2`, every image under
# `path` (Train/ and Valid/ alike) is pooled and split 80/20 at random. The
# previously passed `train="Train", valid="Valid"` were therefore silently
# ignored, and they have been removed so the call says what it does.
# To use the prepared folder split instead, drop `valid_pct`/`seed` and pass
# `train="Train", valid="Valid"`.
# Passing `seed` explicitly makes the random split reproducible without relying
# on the global NumPy RNG being untouched between the seed call and the split.
data = ImageDataBunch.from_folder(
    path,
    valid_pct=0.2,
    seed=SEED,
    ds_tfms=get_transforms(),
    size=(256, 256),  # the model cell's dry-run input must match this size
    bs=16,
    num_workers=4,
).normalize()

In [None]:
# Class vocabulary, number of classes, and train/valid split sizes.
print(data.classes, data.c, len(data.train_ds), len(data.valid_ds))
print(data.train_ds.classes)
print(data.valid_ds.classes)
# Make the consistency check explicit instead of comparing the printed lists by eye.
assert data.train_ds.classes == data.valid_ds.classes == data.classes, "train/valid class lists differ"
assert data.c == len(data.classes), "data.c does not match the number of classes"
assert len(data.valid_ds) > 0, "validation split is empty"

['anadenanthera', 'arecaceae', 'arrabidaea', 'cecropia', 'chromolaena', 'combretum', 'croton', 'dipteryx', 'eucalipto', 'faramea', 'hyptis', 'mabea', 'matayba', 'mimosa', 'mycria', 'protium', 'qualea', 'schinus', 'senegalia', 'serjania', 'syagrus', 'tridax', 'urochloa'] 23 632 158
['anadenanthera', 'arecaceae', 'arrabidaea', 'cecropia', 'chromolaena', 'combretum', 'croton', 'dipteryx', 'eucalipto', 'faramea', 'hyptis', 'mabea', 'matayba', 'mimosa', 'mycria', 'protium', 'qualea', 'schinus', 'senegalia', 'serjania', 'syagrus', 'tridax', 'urochloa']
['anadenanthera', 'arecaceae', 'arrabidaea', 'cecropia', 'chromolaena', 'combretum', 'croton', 'dipteryx', 'eucalipto', 'faramea', 'hyptis', 'mabea', 'matayba', 'mimosa', 'mycria', 'protium', 'qualea', 'schinus', 'senegalia', 'serjania', 'syagrus', 'tridax', 'urochloa']


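## Model with an FFT convolution as the first layer

Only the first layer (a 21×21 kernel) uses `FFTConv2d`; all later layers use standard `nn.Conv2d`.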

In [None]:
def fft_conv_block(ni, nf, size=3, stride=1):
    """Build an FFTConv2d -> BatchNorm2d -> LeakyReLU(0.1) unit.

    This has the same layout as ``conv_block``, but the convolution is ``FFTConv2d``
    from ``fft_conv_pytorch``.

    Args:
        ni: number of input channels.
        nf: number of output channels.
        size: square kernel size.
        stride: convolution stride.

    Returns:
        ``nn.Sequential`` holding the three layers.

    Padding is ``(size - 1) // 2`` when ``size > 2``. For ``size`` 1 or 2 it is
    forced to 1. For example, a 21x21 kernel at stride 1 keeps 256x256 inputs at
    256x256.
    """
    pad = 1 if size <= 2 else (size - 1) // 2
    # bias=False: the BatchNorm2d that follows has its own learnable shift.
    conv = FFTConv2d(ni, nf, kernel_size=size, stride=stride, padding=pad, bias=False)
    return nn.Sequential(
        conv,
        nn.BatchNorm2d(nf),
        nn.LeakyReLU(negative_slope=0.1, inplace=True),
    )

def conv_block(ni, nf, size=3, stride=1):
    """Build a Conv2d -> BatchNorm2d -> LeakyReLU(0.1) unit.

    Args:
        ni: number of input channels.
        nf: number of output channels.
        size: square kernel size.
        stride: convolution stride.

    Returns:
        ``nn.Sequential`` holding the three layers.

    Padding is ``(size - 1) // 2`` when ``size > 2``, which keeps the spatial
    size for odd kernels at stride 1. For ``size`` 1 or 2 the padding is forced
    to 1. A 1x1 convolution therefore *enlarges* each spatial dimension by 2.
    Models built from this block depend on that output size (e.g. the
    flattened size fed into a final Linear layer), so changing this rule
    changes their layer shapes.
    """
    pad = 1 if size <= 2 else (size - 1) // 2
    layers = [
        # bias=False: the BatchNorm2d that follows has its own learnable shift.
        nn.Conv2d(ni, nf, kernel_size=size, stride=stride, padding=pad, bias=False),
        nn.BatchNorm2d(nf),
        nn.LeakyReLU(negative_slope=0.1, inplace=True),
    ]
    return nn.Sequential(*layers)

def triple_conv(ni, nf):
    """Build a bottleneck stack of three ``conv_block`` units.

    The stages are 3x3 (ni -> nf), then 1x1 (nf -> ni), then 3x3 (ni -> nf).

    Because ``conv_block`` pads a 1x1 kernel by 1, the middle stage grows
    each spatial dimension by 2. At stride 1 an HxW input comes out as
    (H+2)x(W+2).
    """
    expand = conv_block(ni, nf)
    squeeze = conv_block(nf, ni, size=1)
    re_expand = conv_block(ni, nf)
    return nn.Sequential(expand, squeeze, re_expand)
def fft_triple_conv(ni, nf, size=3):
    """FFT variant of ``triple_conv``.

    The stages are FFT conv (ni -> nf), then 1x1 ``conv_block`` (nf -> ni), then
    FFT conv (ni -> nf).

    Args:
        ni: number of input channels.
        nf: number of output channels of the two FFT convolutions.
        size: kernel size of the two FFT convolutions. It defaults to 3, like
            ``fft_conv_block``'s default, so ``fft_triple_conv(ni, nf)`` can
            stand in for ``triple_conv(ni, nf)``. Calls that pass ``size``
            behave exactly as before.

    Returns:
        ``nn.Sequential`` of the three stages.

    The 1x1 middle stage is padded by 1 (see ``conv_block``), so it grows each
    spatial dimension by 2.
    """
    return nn.Sequential(
        fft_conv_block(ni, nf, size),
        conv_block(nf, ni, size=1),
        fft_conv_block(ni, nf, size),
    )
def maxpooling():
    """Return a fresh 2x2 max-pool with stride 2 and no padding.

    Each call builds a new module, and the layer halves height and width
    (flooring odd sizes).
    """
    return nn.MaxPool2d(kernel_size=2, stride=2)

In [None]:
# Convolutional feature extractor. Only the first layer is an FFT convolution
# (21x21 kernel); every later layer uses nn.Conv2d via conv_block/triple_conv.
n_classes = data.c  # 23 for this dataset
feature_extractor = nn.Sequential(
    fft_conv_block(3, 8, 21),
    maxpooling(),
    conv_block(8, 16),
    maxpooling(),
    triple_conv(16, 32),
    maxpooling(),
    triple_conv(32, 64),
    maxpooling(),
    triple_conv(64, 128),
    maxpooling(),
    triple_conv(128, 256),
    conv_block(256, 128, size=1),
    conv_block(128, 256),
    conv_block(256, n_classes),
)
# Infer the flattened feature size with a dry run instead of hard-coding it.
# It was 3887 = 23 * 13 * 13 for 256x256 inputs, but it depends on the input
# size and on conv_block's padding of 1x1 kernels, so a literal would have to
# be recomputed by hand whenever either changes.
# The dummy input size must match `size` passed to ImageDataBunch.
# eval() makes BatchNorm use its running stats, so the dry run does not update them.
feature_extractor.eval()
with torch.no_grad():
    n_flat = feature_extractor(torch.zeros(1, 3, 256, 256)).numel()
feature_extractor.train()

# Re-splat the feature layers so the module layout (indices 0..15) matches a
# single flat nn.Sequential, as before.
fft_model4 = nn.Sequential(
    *feature_extractor,
    nn.Flatten(),
    nn.Linear(n_flat, n_classes),
)
fft_learn4 = Learner(data, fft_model4, loss_func=nn.CrossEntropyLoss(), metrics=accuracy)
print(fft_learn4.summary())
fft_learn4.fit_one_cycle(100, max_lr=5e-3)

Sequential
Layer (type)         Output Shape         Param #    Trainable 
_FFTConv             [8, 256, 256]        10,584     True      
______________________________________________________________________
BatchNorm2d          [8, 256, 256]        16         True      
______________________________________________________________________
LeakyReLU            [8, 256, 256]        0          False     
______________________________________________________________________
MaxPool2d            [8, 128, 128]        0          False     
______________________________________________________________________
Conv2d               [16, 128, 128]       1,152      True      
______________________________________________________________________
BatchNorm2d          [16, 128, 128]       32         True      
______________________________________________________________________
LeakyReLU            [16, 128, 128]       0          False     
___________________________________________________

epoch,train_loss,valid_loss,accuracy,time
0,2.85018,3.640631,0.063291,01:44
1,2.359536,1.751827,0.481013,01:38
2,1.986273,1.270019,0.563291,01:39
3,1.651333,1.340276,0.556962,01:39
4,1.451509,1.914966,0.386076,01:41
5,1.370186,1.782522,0.417722,01:43
6,1.335951,2.162958,0.424051,01:42
7,1.332776,3.056602,0.392405,01:43
8,1.268704,1.992818,0.531646,01:43
9,1.341616,1.719124,0.468354,01:43


In [None]:
# Recompute accuracy on the validation split from the trained learner's outputs.
valid_outputs = fft_learn4.get_preds(ds_type=DatasetType.Valid)
probs, targets = valid_outputs  # (model outputs, ground-truth labels)
accuracy(probs, targets)

tensor(0.8924)