In [1]:
import os
import torch
import numpy as np
from tqdm import tqdm


import torch.nn.functional as F
from fastai.vision.all import URLs, untar_data

path = untar_data(URLs.MNIST)
(path).ls()

(#2) [Path('/home/akzsh/.fastai/data/mnist_png/training'),Path('/home/akzsh/.fastai/data/mnist_png/testing')]

In [2]:
class DataLoaders:
    def __init__(
            self, path, train: str, valid:str,
            batch_size:int = 64,
            shuffle: bool = True,
        ) -> None:
        self.path = path
        self.train_path = os.path.join(path, train)
        self.valid_path = os.path.join(path, valid)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.X_train, self.Y_train = self.load_from_folder(directory=self.train_path)
        self.X_valid, self.Y_valid = self.load_from_folder(directory=self.valid_path)

    def __len__(self):
        """Returns total number of data in single batch."""
        if not self.X_train: raise("Data not loaded yet")
        return len(self.X_train) // self.batch_size

    def __str__(self):
        train_data = list(zip(self.x_train, self.y_train))
        return f"{train_data}"

    def load_from_folder(self, directory):
        from PIL import Image
        classes = os.listdir(path=directory)
        X_data, Y_data = [], []

        with tqdm(classes, desc='Classes', unit='class') as pbar_classes:
            for idx, cls_name in enumerate(pbar_classes):
                cls_path = os.path.join(directory, cls_name) # mnist/training/5
                imgs = os.listdir(cls_path) # number of images
                for img_name in imgs:
                    img_path = os.path.join(cls_path, img_name)
                    with Image.open(img_path) as img:
                        img = img.resize((28,28))
                        img = np.array(img, dtype=np.uint8)
                        X_data.append(img) # add image
                        Y_data.append(idx) # add label
            pbar_classes.update(1)
        X_data, Y_data = np.array(X_data), np.array(Y_data)
        if self.shuffle:
            indices = np.arange(len(X_data))
            np.random.shuffle(indices)
            X_data, Y_data = X_data[indices], Y_data[indices]
        return X_data, Y_data

    def get_validation_data(self):
        if not self.x_val and not self.y_val: return None
        return list(zip(self.x_val, self.y_val))


dls = DataLoaders(path=path, train='training', valid='testing')

Classes: 100%|██████████| 10/10 [00:09<00:00,  1.08class/s]
Classes: 100%|██████████| 10/10 [00:01<00:00,  6.72class/s]


In [3]:
# loss function
def mse(preds, targs): return ((preds-targs)**2).mean()

In [4]:
# func to init weights
def init_params(size): return torch.randn(size).float().requires_grad_()

In [5]:
x = torch.tensor(dls.X_train).float()/255
y = torch.tensor(dls.Y_train).long()

x_valid = torch.tensor(dls.X_valid).float()/255
y_valid = torch.tensor(dls.Y_valid).long()

x = x.view(x.size(0), -1)
x_valid = x_valid.view(x_valid.size(0), -1)

x.shape, y.shape, x_valid.shape, y_valid.shape

(torch.Size([60000, 784]),
 torch.Size([60000]),
 torch.Size([10000, 784]),
 torch.Size([10000]))

In [81]:
w1 = init_params((28*28, 10))
b1 = init_params((60000,10))

w1.shape,b1.shape

(torch.Size([784, 10]), torch.Size([60000, 10]))

In [82]:
def linear(xb): return xb@w1 + b1
loss_func = torch.nn.CrossEntropyLoss(reduction='none')

In [144]:
preds = linear(x)
preds = F.sigmoid(preds)
preds = F.softmax(preds, dim=-1)
preds.shape

torch.Size([60000, 10])

In [145]:
accuracy = (torch.argmax(preds, dim=-1) == y).float().mean()
accuracy

tensor(0.0963)

In [146]:
loss = loss_func(preds, y.long())
loss = loss.mean()
loss.backward()

In [147]:
lr = 1e-2
w1.data += w1.grad.data * lr
b1.data += b1.grad.data * lr
w1.grad.zero_(),b1.grad.zero_()
w1

tensor([[-1.6894,  1.2834,  0.0073,  ..., -1.3896, -0.9225, -1.4367],
        [-0.2047,  1.5760,  0.5729,  ..., -1.1914, -0.0602, -1.3356],
        [-0.8234, -0.1500, -0.2901,  ...,  0.5081, -0.2496,  1.3350],
        ...,
        [ 0.5122, -0.1415, -0.5266,  ..., -0.5527, -0.2687,  1.1023],
        [-0.4032, -0.7966, -0.4806,  ...,  0.5505,  1.0336,  0.3005],
        [-0.8680,  1.1277, -1.4253,  ...,  0.2940, -0.4425, -0.3616]],
       requires_grad=True)