In [7]:
# Notebook setup: fastbook helpers plus the fastai vision namespace.
import fastbook
# NOTE(review): setup_book() is fastbook's environment hook; its exact effects are not visible here.
fastbook.setup_book()
from fastai.vision.all import *
from fastbook import *
# Make 'Greys' the default colormap for image plots.
matplotlib.rc('image', cmap='Greys')

In [8]:
# Root folder of the MNIST_SAMPLE dataset; later cells read path/'train' and path/'valid'.
# NOTE(review): presumably untar_data downloads on first use and reuses a local copy afterwards -- confirm.
path = untar_data(URLs.MNIST_SAMPLE)

In [9]:
# Sorted file listings for the two digit classes in the training split.
threes = (path/'train'/'3').ls().sorted()
sevens = (path/'train'/'7').ls().sorted()

# One tensor per image file, in listing order.
three_tensors = [tensor(Image.open(img_file)) for img_file in threes]
seven_tensors = [tensor(Image.open(img_file)) for img_file in sevens]

# Stack each class into a single float tensor and divide by 255
# (assumes 8-bit pixel values, which would put the result in [0, 1] -- TODO confirm).
stacked_threes = torch.stack(three_tensors).float()/255
stacked_sevens = torch.stack(seven_tensors).float()/255

In [10]:
# Validation images for each digit: load every file, stack into one tensor,
# convert to float and divide by 255 (same scaling as the training tensors).
# Note: unlike the training listings, these file listings are not sorted.
valid_3_tens = torch.stack(
    [tensor(Image.open(img_file)) for img_file in (path/'valid'/'3').ls()]
).float()/255

valid_7_tens = torch.stack(
    [tensor(Image.open(img_file)) for img_file in (path/'valid'/'7').ls()]
).float()/255

In [11]:
# Training set: threes first, then sevens, each image flattened to a row of 28*28 values.
# Labels form a column vector: 1 for every three, 0 for every seven.
train_x = torch.cat((stacked_threes, stacked_sevens)).view(-1, 28*28)
train_y = tensor([1]*len(threes) + [0]*len(sevens)).unsqueeze(1)
# A dataset here is simply a list of (input_row, label) tuples.
dset = [(x, y) for x, y in zip(train_x, train_y)]

# Validation set, built the same way from the validation tensors.
valid_x = torch.cat((valid_3_tens, valid_7_tens)).view(-1, 28*28)
valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)
valid_dset = [(x, y) for x, y in zip(valid_x, valid_y)]

In [12]:
def init_params(size, std=1.0):
    """Create a randomly initialised parameter tensor that tracks gradients.

    Args:
        size: Shape passed straight to ``torch.randn`` (an int or a tuple of ints).
        std: Factor the standard-normal samples are multiplied by.

    Returns:
        A tensor of the requested shape with ``requires_grad`` switched on.
    """
    scaled = torch.randn(size) * std
    return scaled.requires_grad_()

In [13]:
# Parameters of the linear model: a (28*28, 1) weight column (one weight per
# flattened input value) and a single bias term. Both track gradients.
weights = init_params((28*28,1))
bias = init_params(1)

In [14]:
def linear1(xb):
    """Apply the linear model to a batch of inputs.

    Reads the module-level ``weights`` and ``bias`` tensors.

    Args:
        xb: Batch of inputs that can be matrix-multiplied with ``weights``.

    Returns:
        ``xb @ weights + bias``.
    """
    weighted = xb @ weights
    return weighted + bias

In [15]:
def mnist_loss(predictions, targets):
    """Mean distance between sigmoid-squashed predictions and binary targets.

    Args:
        predictions: Raw model outputs (any real values); a sigmoid is applied here.
        targets: Tensor of labels; entries equal to 1 are treated as the positive class.

    Returns:
        A scalar tensor: the mean of ``1 - p`` where the target is 1 and ``p``
        elsewhere, with ``p = sigmoid(predictions)``.
    """
    probs = torch.sigmoid(predictions)
    # How far each probability is from its target.
    distances = torch.where(targets == 1, 1 - probs, probs)
    return distances.mean()

In [18]:
def train(n_steps=1000, lr=0.5):
    """Train a fresh linear model on the full training set with gradient descent.

    New weights and a new bias are created on every call, so the module-level
    ``weights`` / ``bias`` are never read or modified. Each step runs a forward
    pass over all of ``train_x``, back-propagates ``mnist_loss`` and moves both
    parameters against their gradients.

    Args:
        n_steps: Number of full-batch gradient steps (default 1000).
        lr: Learning rate applied to every step (default 0.5).

    Returns:
        A ``(loss, params)`` tuple. ``loss`` is the loss tensor from the last
        step (``None`` when ``n_steps`` is 0). ``params`` is a dict holding the
        trained tensors under ``"weights"`` and ``"bias"``; the misspelled
        legacy key ``"bias:"`` is kept as an alias for ``"bias"``.
    """
    cur_weights = init_params((28*28,1))
    cur_bias = init_params(1)
    params = (cur_weights, cur_bias)

    loss = None
    for _ in range(n_steps):
        # Use the local bias; the global `bias` belongs to a different model.
        preds = train_x@cur_weights + cur_bias
        loss = mnist_loss(preds, train_y)
        loss.backward()
        for p in params:
            p.data -= lr * p.grad.data
            # backward() adds into .grad, so clear it before the next step.
            p.grad.zero_()

    return (loss, {
        "weights": cur_weights,
        "bias": cur_bias,
        # Legacy misspelled key, kept so existing lookups keep working.
        "bias:": cur_bias,
    })
        

In [19]:
# BATCHES: train on mini-batches served by DataLoaders instead of the whole dataset at once

In [71]:
# Mini-batch loaders (256 items per batch) over the training and validation datasets.
# NOTE(review): no shuffle argument is passed and dset is built as all threes
# followed by all sevens, so training batches may be single-class -- confirm the
# DataLoader default and consider shuffle=True for `dl`.
dl = DataLoader(dset, batch_size=256)
valid_dl = DataLoader(valid_dset, batch_size=256)

In [72]:
def calc_grad(xb, yb, model):
    """Back-propagate ``mnist_loss`` for one batch.

    Args:
        xb: Batch of inputs handed to ``model``.
        yb: Matching batch of targets.
        model: Callable mapping ``xb`` to predictions.

    Returns:
        None. The only effect is the ``backward()`` call on the batch loss.
    """
    mnist_loss(model(xb), yb).backward()

In [79]:
def train_epoch(model, lr, params):
    """Make one pass over the module-level ``dl``, stepping the parameters per batch.

    Args:
        model: Callable mapping a batch of inputs to predictions.
        lr: Learning rate that scales each gradient step.
        params: Iterable of tensors to update. It is walked once per batch, so
            pass a reusable collection (tuple/list), not a one-shot generator.
    """
    for inputs, targets in dl:
        # Fill .grad on the parameters for this batch.
        calc_grad(inputs, targets, model)
        for param in params:
            param.data -= lr * param.grad
            # Reset so the next backward() does not add onto this batch's gradient.
            param.grad.zero_()

In [80]:
def batch_accuracy(xb, yb):
    """Fraction of a batch classified correctly at a 0.5 threshold.

    Args:
        xb: Raw model outputs for the batch (a sigmoid is applied here).
        yb: Targets, compared elementwise against ``sigmoid(xb) > 0.5``.

    Returns:
        Scalar float tensor holding the mean of the elementwise matches.
    """
    probs = torch.sigmoid(xb)
    is_match = (probs > 0.5) == yb
    return is_match.float().mean()

In [81]:
# Sanity check on the first four training rows. `batch` was never defined in
# this notebook, so slice train_x directly to match the train_y[:4] targets.
batch_accuracy(linear1(train_x[:4]), train_y[:4])

tensor(0.)

In [82]:
def validate_epoch(model):
    """Score ``model`` on the module-level ``valid_dl``.

    Args:
        model: Callable mapping a batch of inputs to raw predictions.

    Returns:
        A Python float rounded to 4 decimals: the plain (unweighted) mean of
        the per-batch accuracies, so a smaller final batch counts as much as
        a full one.
    """
    per_batch = []
    for xb, yb in valid_dl:
        per_batch.append(batch_accuracy(model(xb), yb))
    mean_acc = torch.stack(per_batch).mean().item()
    return round(mean_acc, 4)

In [83]:
# Learning rate for the mini-batch loop, and the tensors train_epoch will update.
# `params` refers to the module-level weights/bias, so training continues from
# whatever values they currently hold (re-run the init cell to start fresh).
lr = 1
params = weights, bias

In [84]:
# Train for 20 epochs, printing the validation accuracy after each one on a single line.
for i in range(20):
    train_epoch(linear1, lr, params)
    print(validate_epoch(linear1), end=' ')

0.7304 0.851 0.9013 0.9286 0.9389 0.9447 0.9526 0.955 0.9589 0.9613 0.9623 0.9623 0.9633 0.9638 0.9647 0.9662 0.9667 0.9672 0.9677 0.9677 