<a href="https://colab.research.google.com/github/animefan380/fastai_practice/blob/main/Copy_of_04_mnist_basics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
#hide
!pip install -Uqq fastbook
import fastbook
fastbook.setup_book()

[K     |████████████████████████████████| 727kB 5.8MB/s 
[K     |████████████████████████████████| 194kB 10.3MB/s 
[K     |████████████████████████████████| 51kB 5.4MB/s 
[K     |████████████████████████████████| 1.1MB 13.2MB/s 
[K     |████████████████████████████████| 61kB 6.7MB/s 
[?25hMounted at /content/gdrive


In [3]:
#hide
from fastai.vision.all import *
from fastbook import *

matplotlib.rc('image', cmap='Greys')

[[chapter_mnist_basics]]

# Under the Hood: Training a Digit Classifier

In [4]:
path = untar_data(URLs.MNIST)

In [5]:
#hide
Path.BASE_PATH = path

In [7]:
zeroes = (path/'training'/'0').ls().sorted()
ones = (path/'training'/'1').ls().sorted()
twos = (path/'training'/'2').ls().sorted()
threes = (path/'training'/'3').ls().sorted()
fours = (path/'training'/'4').ls().sorted()
fives = (path/'training'/'5').ls().sorted()
sixes = (path/'training'/'6').ls().sorted()
sevens = (path/'training'/'7').ls().sorted()
eights = (path/'training'/'8').ls().sorted()
nines = (path/'training'/'9').ls().sorted()



## First Try: Pixel Similarity

In [8]:
zero_tensors = [tensor(Image.open(o)) for o in zeroes]
one_tensors = [tensor(Image.open(o)) for o in ones]
two_tensors = [tensor(Image.open(o)) for o in twos]
three_tensors = [tensor(Image.open(o)) for o in threes]
four_tensors = [tensor(Image.open(o)) for o in fours]
five_tensors = [tensor(Image.open(o)) for o in fives]
six_tensors = [tensor(Image.open(o)) for o in sixes]
seven_tensors = [tensor(Image.open(o)) for o in sevens]
eight_tensors = [tensor(Image.open(o)) for o in sevens]
nine_tensors = [tensor(Image.open(o)) for o in nines]


len(three_tensors),len(seven_tensors)

(6131, 6265)

In [9]:
stacked_zeroes = torch.stack(zero_tensors).float()/255
stacked_ones = torch.stack(one_tensors).float()/255
stacked_twos= torch.stack(two_tensors).float()/255
stacked_threes = torch.stack(three_tensors).float()/255
stacked_fours = torch.stack(four_tensors).float()/255
stacked_fives = torch.stack(five_tensors).float()/255
stacked_sixes = torch.stack(six_tensors).float()/255
stacked_sevens = torch.stack(seven_tensors).float()/255
stacked_eights = torch.stack(eight_tensors).float()/255
stacked_nines = torch.stack(nine_tensors).float()/255


stacked_nines.shape

torch.Size([5949, 28, 28])

In [11]:
valid_0_tens = torch.stack([tensor(Image.open(o)) 
                            for o in (path/'testing'/'0').ls()]).float()/255
valid_1_tens = torch.stack([tensor(Image.open(o)) 
                            for o in (path/'testing'/'1').ls()]).float()/255
valid_2_tens = torch.stack([tensor(Image.open(o)) 
                            for o in (path/'testing'/'2').ls()]).float()/255
valid_3_tens = torch.stack([tensor(Image.open(o)) 
                            for o in (path/'testing'/'3').ls()]).float()/255
valid_4_tens = torch.stack([tensor(Image.open(o)) 
                            for o in (path/'testing'/'4').ls()]).float()/255
valid_5_tens = torch.stack([tensor(Image.open(o)) 
                            for o in (path/'testing'/'5').ls()]).float()/255
valid_6_tens = torch.stack([tensor(Image.open(o)) 
                            for o in (path/'testing'/'6').ls()]).float()/255
valid_6_tens = valid_6_tens.float()/255

valid_7_tens = torch.stack([tensor(Image.open(o)) 
                            for o in (path/'testing'/'7').ls()]).float()/255

valid_8_tens = torch.stack([tensor(Image.open(o)) 
                            for o in (path/'testing'/'8').ls()]).float()/255

valid_9_tens = torch.stack([tensor(Image.open(o)) 
                            for o in (path/'testing'/'9').ls()]).float()/255
valid_9_tens = valid_9_tens.float()/255

valid_6_tens.shape,valid_9_tens.shape

(torch.Size([958, 28, 28]), torch.Size([1009, 28, 28]))

In [13]:
def mnist_distance(a,b): return (a-b).abs().mean((-1,-2))

In [14]:
def is_6(x): return mnist_distance(x,mean6) < mnist_distance(x,mean9)

## The MNIST Loss Function

In [15]:
train_x = torch.cat([stacked_zeroes,
                     stacked_ones,
                     stacked_twos,
                     stacked_threes,
                     stacked_fours,
                     stacked_fives,
                     stacked_sixes,
                     stacked_sevens,
                     stacked_eights,
                     stacked_nines]).view(-1, 28*28)

In [16]:
train_y = tensor([0]*len(zeroes) +
                [1]*len(ones) + 
                [2]*len(twos) + 
                [3]*len(threes) +
                 [4]*len(fours) +
                 [5]*len(fives) +
                 [6]*len(sixes) +
                 [7]*len(sevens) +
                 [8]*len(eights) +
                 [9]*len(nines)).unsqueeze(1)
train_x.shape,train_y.shape

(torch.Size([60414, 784]), torch.Size([60000, 1]))

In [17]:
dset = list(zip(train_x,train_y))
x,y = dset[0]
x.shape,y

(torch.Size([784]), tensor([0]))

In [18]:
valid_x = torch.cat([valid_0_tens,
                     valid_1_tens,
                     valid_2_tens,
                     valid_4_tens,
                     valid_5_tens,
                     valid_6_tens,
                     valid_7_tens,
                     valid_8_tens,
                     valid_9_tens,
                     ]).view(-1, 28*28)
valid_y = tensor([0]*len(valid_0_tens) + 
                 [1]*len(valid_1_tens) +
                 [2]*len(valid_2_tens) + 
                 [3]*len(valid_3_tens) + 
                 [4]*len(valid_4_tens) + 
                 [5]*len(valid_5_tens) + 
                 [6]*len(valid_6_tens) + 
                 [7]*len(valid_7_tens) + 
                 [8]*len(valid_8_tens) + 
                 [9]*len(valid_9_tens)).unsqueeze(1)
valid_dset = list(zip(valid_x,valid_y))

In [19]:
valid_x.shape,valid_y.shape

(torch.Size([8990, 784]), torch.Size([10000, 1]))

Now we need an (initially random) weight for every pixel (this is the *initialize* step in our seven-step process):

In [20]:
def init_params(size, std=1.0): return (torch.randn(size)*std).requires_grad_()

In [21]:
weights = init_params((28*28,1))

In [22]:
bias = init_params(1)

We can now calculate a prediction for one image:

In [23]:
(train_x[0]*weights.T).sum() + bias

tensor([-9.2751], grad_fn=<AddBackward0>)

In [24]:
def linear1(xb): return xb@weights + bias
preds = linear1(train_x)
preds

tensor([[ -9.2751],
        [-23.0076],
        [-15.8010],
        ...,
        [ -8.8975],
        [-10.6736],
        [ -6.4075]], grad_fn=<AddBackward0>)

Let's check our accuracy. To decide if an output represents a 3 or a 7, we can just check whether it's greater than 0, so our accuracy for each item can be calculated (using broadcasting, so no loops!) with:

In [26]:
corrects = (preds).float() == train_y
corrects

RuntimeError: ignored