In [1]:
from fastai.vision.all import *

In [2]:
path = untar_data(URLs.MNIST_SAMPLE)
Path.BASE_PATH = path
threes_path = (path/'train/3').ls().sorted()
sevens_path = (path/'train/7').ls().sorted()
three_tensors = [tensor(Image.open(o)) for o in threes_path]
seven_tensors = [tensor(Image.open(o)) for o in sevens_path]
stacked_threes = torch.stack(three_tensors).float()/255
stacked_sevens = torch.stack(seven_tensors).float()/255
train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)
train_y = tensor([1]*len(threes_path) + [0]*len(sevens_path)).unsqueeze(1)
dset = list(zip(train_x, train_y))

valid_3_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid/3').ls().sorted()])
valid_3_tens = valid_3_tens.float()/255
valid_7_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid/7').ls().sorted()])
valid_7_tens = valid_7_tens.float()/255
valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)
valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)
valid_dset = list(zip(valid_x, valid_y))

In [3]:
torch.manual_seed(0)


<torch._C.Generator at 0x22944bda3d0>

In [4]:
dl = DataLoader(dset, batch_size=256)
valid_dl = DataLoader(valid_dset, batch_size= 256)

In [5]:
linear_model = nn.Linear(28*28,1)

In [6]:
w, b = linear_model.parameters()
w.shape, b.shape

(torch.Size([1, 784]), torch.Size([1]))

In [7]:
class BasicOptim:
    def __init__(self, params, lr): self.params, self.lr = list(params), lr

    def step(self, *args, **kwargs):
        for p in self.params: p.data -= p.grad.data * self.lr

    def zero_grad(self, *args, **kwargs):
        for p in self.params: p.grad = None

In [8]:
lr = 1.

In [9]:
opt = BasicOptim(linear_model.parameters(), lr)

In [10]:
def mnist_loss(predictions, targets):
    predictions = predictions.sigmoid()
    return torch.where(targets==1, 1-predictions, predictions).mean()

In [11]:
def calc_grad(xb,yb,model):
    preds = model(xb)
    loss = mnist_loss(preds, yb)
    loss.backward()

In [12]:
def train_epoch(model):
    for xb, yb in dl:
        calc_grad(xb, yb, model)
        opt.step()
        opt.zero_grad()

In [13]:
def batch_accuracy(xb, yb):
    preds = xb.sigmoid()
    correct = (preds>0.5) == yb
    return correct.float().mean()

In [14]:
def validate_epoch(model):
    accs = [batch_accuracy(model(xb), yb) for xb,yb in valid_dl]
    return round(torch.stack(accs).mean().item(), 4)

In [15]:
validate_epoch(linear_model)

0.419

In [16]:
def train_model(model, epochs):
    for i in range(epochs):
        print(f"Epoch {i+1}")
        train_epoch(model)
        print(validate_epoch(model), end='\n')

In [17]:
train_model(linear_model, 20)

Epoch 1


0.4932
Epoch 2
0.8843
Epoch 3
0.814
Epoch 4
0.9087
Epoch 5
0.9336
Epoch 6
0.9463
Epoch 7
0.9555
Epoch 8
0.9614
Epoch 9
0.9663
Epoch 10
0.9673
Epoch 11
0.9697
Epoch 12
0.9712
Epoch 13
0.9741
Epoch 14
0.9751
Epoch 15
0.9761
Epoch 16
0.977
Epoch 17
0.9775
Epoch 18
0.9775
Epoch 19
0.9785
Epoch 20
0.9785


this SGD class provided by fastai does the same thing as our BasicOptim

In [18]:
linear_model = nn.Linear(28*28,1)
opt = SGD(linear_model.parameters(), lr)
train_model(linear_model, 20)

Epoch 1
0.4932
Epoch 2
0.8911
Epoch 3
0.812
Epoch 4
0.9082
Epoch 5
0.9321
Epoch 6
0.9438
Epoch 7
0.9551
Epoch 8
0.9614
Epoch 9
0.9648
Epoch 10
0.9668
Epoch 11
0.9692
Epoch 12
0.9721
Epoch 13
0.9731
Epoch 14
0.9746
Epoch 15
0.9761
Epoch 16
0.9765
Epoch 17
0.9775
Epoch 18
0.9775
Epoch 19
0.978
Epoch 20
0.9785


In [19]:
dls = DataLoaders(dl, valid_dl)
type(dls)

fastai.data.core.DataLoaders

In [20]:
learn = Learner(dls,nn.Linear(28*28,1), opt_func=SGD, loss_func=mnist_loss,  metrics=batch_accuracy)
type(learn)

fastai.learner.Learner

In [21]:
learn.fit(10,lr=lr)

epoch,train_loss,valid_loss,batch_accuracy,time
0,0.636848,0.503167,0.495584,00:00
1,0.436319,0.234523,0.791462,00:00
2,0.163384,0.164356,0.851325,00:00
3,0.073413,0.101648,0.914622,00:00
4,0.04019,0.075502,0.934249,00:00
5,0.027157,0.060979,0.948479,00:00
6,0.021747,0.051856,0.957311,00:00
7,0.019298,0.045762,0.962218,00:00
8,0.018026,0.041461,0.965653,00:00
9,0.017245,0.038275,0.966634,00:00


**Adding Nonlinearity**