In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import numpy as np 
import matplotlib.pyplot as plt

In [2]:
import fastai
from fastai.vision.all import *

In [3]:
# Both MNIST splits share the same storage location, download behaviour and
# pixel transform (PIL image -> float tensor scaled to [0, 1]); only `train` differs.
mnist_common = dict(root='./data', download=True, transform=transforms.ToTensor())
train_ds = datasets.MNIST(train=True, **mnist_common)
test_ds = datasets.MNIST(train=False, **mnist_common)

In [4]:
classes = train_ds.classes
# The CNN's output layer is hard-coded to 10 logits, so fail fast if the
# dataset ever exposes a different number of labels.
assert len(classes) == 10, f"expected 10 MNIST classes, got {len(classes)}"

In [5]:
# Plain PyTorch loaders: the training split is reshuffled on every pass,
# the test split is read in its fixed order.
# NOTE(review): a batch size of 8 makes each epoch slow (minutes per epoch in
# the output below); consider a larger value if results allow.
BATCH_SIZE = 8
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

In [6]:
from fastai.data.core import DataLoaders

In [7]:
# fastai takes the first loader as the training set and the second as the
# validation set; the MNIST test split is used for validation here.
loaders = (train_dl, test_dl)
dls = DataLoaders(*loaders)

In [8]:
import torch.nn.functional as F

In [9]:
#Define CNN Model
class CNN(nn.Module):
    """Small two-convolution CNN classifier for 28x28 images (MNIST by default).

    Architecture:
        conv 3x3 (16) -> ReLU -> maxpool 2 -> conv 3x3 (32) -> ReLU -> maxpool 2
        -> flatten -> linear(800 -> 64) -> ReLU -> linear(64 -> num_classes)

    Spatial sizes for a 28x28 input (unpadded 3x3 convs shrink H/W by 2,
    2x2 pools halve them, rounding down):
        28 -conv-> 26 -pool-> 13 -conv-> 11 -pool-> 5
    so the flattened feature vector has 32 * 5 * 5 = 800 entries.

    Args:
        in_channels: number of input image channels (1 for MNIST).
        num_classes: number of output logits (10 for MNIST digits).

    ``forward(x)`` expects ``x`` of shape (batch, in_channels, 28, 28) and
    returns raw, unnormalised logits of shape (batch, num_classes).
    """

    # Flattened feature size after the second pooling layer for 28x28 inputs.
    FLAT_FEATURES = 32 * 5 * 5

    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        # No padding: using padding='same' instead would keep the spatial size
        # (28 -> 14 -> 7) and FLAT_FEATURES would then have to become 32 * 7 * 7.
        self.conv1 = nn.Conv2d(in_channels, 16, 3)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.linear1 = nn.Linear(self.FLAT_FEATURES, 64)
        self.linear2 = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension: (batch, 32, 5, 5) -> (batch, 800).
        # Unlike view(-1, 800), this never invents extra "samples" from an input of
        # the wrong spatial size; linear1 raises a shape error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.linear1(x))
        return self.linear2(x)
        

OptimWrapper is a fastai wrapper class that lets existing PyTorch optimizers be used by a fastai `Learner`.

In [10]:
from fastai.optimizer import OptimWrapper
from functools import partial

In [11]:
# fastai's Learner builds its optimizer by calling opt_func with the model's
# parameters; OptimWrapper adapts the plain PyTorch Adam optimizer for that.
base_optimizer = torch.optim.Adam
opt_func = partial(OptimWrapper, opt=base_optimizer)

In [12]:
import fastai.callback.schedule

from fastai.metrics import accuracy

In [13]:
# Seed torch's global RNG right before building the model. Without a seed the
# random weight initialisation, and the training loader's shuffle order (drawn
# from the global RNG because no generator is passed), change on every run.
# NOTE(review): some CUDA kernels remain nondeterministic even when seeded.
SEED = 42
torch.manual_seed(SEED)
learn = Learner(
    dls,
    CNN(),
    loss_func=nn.CrossEntropyLoss(),
    opt_func=opt_func,
    metrics=accuracy,
)

In [14]:
# Train with fastai's one-cycle learning-rate schedule, peaking at MAX_LR.
N_EPOCHS = 2
MAX_LR = 1e-2
learn.fit_one_cycle(n_epoch=N_EPOCHS, lr_max=MAX_LR)

epoch,train_loss,valid_loss,accuracy,time
0,0.107296,0.101471,0.9722,04:31
1,0.04113,0.045561,0.9872,03:32
