In [None]:
!pip install -Uqq fastbook
!pip install -Uq fastai


In [None]:
from fastbook import *

path = untar_data(URLs.MNIST_SAMPLE)

def _load_images(split, digit):
  """Load every image file in path/split/digit as a tensor, in sorted filename order.

  `.ls()` order may depend on the filesystem, so sorting keeps the image order
  (and therefore batch composition downstream) reproducible. The `with` block
  closes each file as soon as its pixels have been converted.
  """
  imgs = []
  for fn in sorted((path/split/digit).ls()):
    with Image.open(fn) as im:
      imgs.append(tensor(im))
  return imgs

def _stack_scaled(imgs):
  """Stack a list of image tensors into one float tensor divided by 255.

  Presumably maps 8-bit pixel values (0-255) into [0, 1] — TODO confirm image dtype.
  """
  return torch.stack(imgs).float() / 255

img3_array = _load_images('train', '3')
img7_array = _load_images('train', '7')

valid_img3_array = _load_images('valid', '3')
valid_img7_array = _load_images('valid', '7')

stacked_3 = _stack_scaled(img3_array)
stacked_7 = _stack_scaled(img7_array)

valid_stacked_3 = _stack_scaled(valid_img3_array)
valid_stacked_7 = _stack_scaled(valid_img7_array)

In [None]:
def _make_dset(threes, sevens):
  """Flatten stacked 28x28 images and label them for the 3-vs-7 task.

  threes, sevens: stacked image tensors (one image per leading index).
  Returns (x, y, dset): x is (N, 784) with all 3s before all 7s, y is an (N, 1)
  integer tensor holding 1 for each 3 and 0 for each 7, and dset is the list of
  (x_i, y_i) pairs.
  """
  x = torch.cat([threes, sevens]).view(-1, 28 * 28)
  y = torch.tensor([1] * len(threes) + [0] * len(sevens)).unsqueeze(1)
  return x, y, list(zip(x, y))

train_x, train_y, dset = _make_dset(stacked_3, stacked_7)
valid_x, valid_y, valid_dset = _make_dset(valid_stacked_3, valid_stacked_7)

# Wrap each dataset (a list of (independent, target) pairs) in a DataLoader yielding batches of 256.
# NOTE(review): no shuffle argument is passed, so the training batches follow the 3s-then-7s
# order built above and are almost entirely single-class; consider shuffle=True for `dl`.
dl = DataLoader(dset, batch_size=256)
valid_dl = DataLoader(valid_dset, batch_size=256)

def init_params(size, std=1.0):
  """Create a trainable tensor of the given size drawn from N(0, std**2).

  size: anything torch.randn accepts as a shape (an int or a tuple).
  std: multiplier applied to standard-normal samples.
  Returns a leaf tensor with requires_grad enabled.
  """
  samples = torch.randn(size)
  scaled = samples * std
  scaled.requires_grad_()
  return scaled

def linear1(xb, w=None, b=None):
  """Linear model: xb @ w + b.

  w and b default to the module-level `weights` and `bias`, looked up at call
  time, so `linear1(xb)` behaves exactly as before; passing them explicitly lets
  the model be used without relying on notebook globals.
  """
  if w is None: w = weights
  if b is None: b = bias
  return xb @ w + b

def mnist_loss(predictions, targets):
  """Mean distance of sigmoid(predictions) from the 0/1 targets.

  predictions: raw model outputs (logits), one per sample.
  targets: 1 for a "3", 0 for a "7".
  Returns a scalar tensor: 1 - p for samples labelled 1, p for samples labelled 0,
  averaged over the batch.

  The sigmoid outputs are reshaped to the targets' shape first. Without this, an
  (N,) prediction vector and (N, 1) targets silently broadcast in torch.where to
  an (N, N) matrix that pairs every prediction with every label. A size mismatch
  now raises instead of broadcasting.
  """
  predictions = predictions.sigmoid().reshape(targets.shape)
  return torch.where(targets == 1, 1 - predictions, predictions).mean()

def calc_grad(xb, yb, model):
  """Backpropagate mnist_loss for one batch.

  Runs `model` on `xb`, scores the output against `yb` with mnist_loss, and calls
  backward() so the gradients land in the parameters' `.grad`. Returns nothing;
  train_epoch zeroes those gradients after each update.
  """
  mnist_loss(model(xb), yb).backward()

def train_epoch(model, lr, params, loader=None):
  """Run one epoch of plain SGD.

  model: callable handed to calc_grad for each batch.
  lr: step size.
  params: iterable of tensors with requires_grad. Each is updated in place as
    p -= p.grad * lr, then its gradient is zeroed.
  loader: iterable of (xb, yb) batches; defaults to the module-level training loader `dl`.
  """
  if loader is None: loader = dl
  for xb, yb in loader:
    calc_grad(xb, yb, model)
    # The update must not be recorded by autograd. no_grad() is the supported way
    # to express that, instead of mutating through the legacy `.data` attribute.
    with torch.no_grad():
      for p in params:
        p -= p.grad * lr
        p.grad.zero_()

def batch_accuracy(xb, yb):
  """Fraction of a batch classified correctly.

  xb: the model's outputs (logits) for the batch. Despite the name, these are not the inputs.
  yb: 0/1 targets.
  A sample is predicted as 1 when sigmoid(logit) > 0.5. Returns a 0-dim float tensor.

  The sigmoid outputs are reshaped to yb's shape first, so (N,) outputs and (N, 1)
  targets are compared element by element. Otherwise they would broadcast to an
  (N, N) comparison whose mean is not an accuracy.
  """
  preds = xb.sigmoid().reshape(yb.shape)
  correct = (preds > 0.5) == yb
  return correct.float().mean()

def validate_epoch(model, loader=None):
  """Return the model's accuracy over a loader, rounded to 4 decimal places.

  model: callable mapping a batch of inputs to logits.
  loader: iterable of (xb, yb) batches; defaults to the module-level `valid_dl`.

  Per-batch accuracies are weighted by batch size, so every sample counts equally.
  A plain mean of batch means would over-weight a smaller final batch.
  Raises ValueError if the loader yields no samples.
  """
  if loader is None: loader = valid_dl
  n_correct, n_seen = 0.0, 0
  # Evaluation needs no gradients, so skip building the autograd graph.
  with torch.no_grad():
    for xb, yb in loader:
      n = len(yb)
      n_correct += batch_accuracy(model(xb), yb).item() * n
      n_seen += n
  if n_seen == 0:
    raise ValueError("validate_epoch: loader yielded no samples")
  return round(n_correct / n_seen, 4)


# Seed so the random initialisation (and the accuracies printed below) are reproducible.
torch.manual_seed(42)

# weights must be a (784, 1) column so linear1 returns (batch, 1) predictions that
# line up with the (batch, 1) targets. `init_params((28 * 28), 1)` passed 1 as std,
# not as a second dimension, and produced a flat (784,) vector instead.
weights = init_params((28 * 28, 1))
bias = init_params(1)


lr = 1.
params = weights, bias
train_epoch(linear1, lr, params)
# A bare expression mid-cell is never displayed, so print the first-epoch accuracy explicitly.
print(validate_epoch(linear1))

for _ in range(20):
  train_epoch(linear1, lr, params)
  print(validate_epoch(linear1), end=' ')


In [None]:
# Proof of concept: use autograd to compute the gradients needed to fit a quadratic
# to synthetic data by gradient descent.

# 20 time steps (0..19) and a synthetic "speed": a parabola with its vertex at t = 9.5.
time = torch.arange(20, dtype=torch.float32)
speed = (time - 9.5) ** 2

def f(t, params):
  """Evaluate the quadratic a*t**2 + b*t + c, where params unpacks to (a, b, c)."""
  coef_sq, coef_lin, const = params
  return coef_sq * t**2 + coef_lin * t + const

def mse(preds, targets):
  """Mean squared error between preds and targets (element-wise, then averaged)."""
  residual = preds - targets
  return residual.pow(2).mean()

# Seed so the random starting coefficients (and the fitted curve plotted below) are reproducible.
torch.manual_seed(42)
# Random initial (a, b, c); requires_grad_ lets apply_step read params.grad after backward().
params = torch.randn(3).float().requires_grad_()

def show_preds(preds, ax=None):
  """Scatter the observed `speed` (default colour) and `preds` (red) against `time`.

  preds: predictions aligned with the module-level `time`, converted with to_np.
  ax: matplotlib Axes to draw on; a new one is created when omitted.
  The axis limits are fixed at x in [-30, 30] and y in [-300, 300].
  """
  if ax is None: ax = plt.subplots()[1]
  ax.scatter(time, speed, label='actual')
  ax.scatter(time, to_np(preds), color='red', label='predicted')
  ax.set_ylim(-300, 300)
  ax.set_xlim(-30, 30)
  # Label the figure so it stands on its own when the notebook is skimmed.
  ax.set_xlabel('time')
  ax.set_ylabel('speed')
  ax.legend()

# Step size for the quadratic fit; read as a module-level global by apply_step.
lr = 1e-5
def apply_step(params, prn=True):
  """Take one gradient-descent step fitting f(time, params) to speed under MSE.

  params: leaf tensor with requires_grad, updated in place by -lr * grad. Its
    gradient is reset to None afterwards.
  prn: when true, print the loss measured before the update.
  Returns the predictions computed before the update.
  """
  preds = f(time, params)
  loss = mse(preds, speed)
  loss.backward()
  # Update under no_grad instead of through `.data`, so autograd does not record the step.
  with torch.no_grad():
    params -= lr * params.grad
  params.grad = None
  if prn: print(loss.item())
  return preds

# Take 100 silent gradient-descent steps, then plot the fitted curve against the data.
for _ in range(100):
  apply_step(params, prn=False)

optimized_preds = f(time, params)
show_preds(optimized_preds)