In [1]:
!git clone https://github.com/Jow1e/jupy.git
%cd jupy

Cloning into 'jupy'...
remote: Enumerating objects: 72, done.[K
remote: Counting objects: 100% (72/72), done.[K
remote: Compressing objects: 100% (71/71), done.[K
remote: Total 72 (delta 27), reused 0 (delta 0), pack-reused 0[K
Unpacking objects: 100% (72/72), done.
/content/jupy


In [2]:
import numpy as np
import matplotlib.pyplot as plt
import jupy as jp

from torchvision.datasets import MNIST


np.random.seed(42)


def loader(X, Y, batch_size, shuffle=True):
	n = X.shape[0]
	indices = np.arange(n)
	
	if shuffle:
		np.random.shuffle(indices)
	
	for start in range(0, n, batch_size):
		end = min(start + batch_size, n)
		
		if end < start + batch_size:
			break
		
		batch_idx = indices[start:end]
		yield jp.Tensor(X[batch_idx], dtype=float), jp.Tensor(Y[batch_idx], dtype=int)


def preprocess(dataset):
	X = np.array([np.asarray(dataset[i][0]) for i in range(len(dataset))])
	X = X.reshape((-1, 784)).astype(float) / 255
	Y = np.array([dataset[i][1] for i in range(len(dataset))])
	return X, Y

n_epochs = 15
batch_size = 128

In [3]:
train_dataset = MNIST('mnist', train=True, download=True)
test_dataset = MNIST('mnist_test', train=False, download=True)

X_train, Y_train = preprocess(train_dataset)
X_test, Y_test = preprocess(test_dataset)

model = jp.Sequential(
	jp.Linear(784, 256),
	jp.PReLU(256),
	
	jp.Linear(256, 256),
	jp.PReLU(256),
	
	jp.Linear(256, 256),
	jp.PReLU(256),
	
	jp.Linear(256, 10)
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting mnist/MNIST/raw/train-images-idx3-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to mnist/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting mnist/MNIST/raw/train-labels-idx1-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist_test/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting mnist_test/MNIST/raw/train-images-idx3-ubyte.gz to mnist_test/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to mnist_test/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting mnist_test/MNIST/raw/train-labels-idx1-ubyte.gz to mnist_test/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to mnist_test/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting mnist_test/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist_test/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to mnist_test/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting mnist_test/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist_test/MNIST/raw



In [4]:
optim = jp.AdamW(model.parameters(), lr=0.001)

for epoch in range(n_epochs):
	total = 0
	sum_loss = 0
	
	model.train()
	for X, Y in loader(X_train, Y_train, batch_size):
		pred = model(X)
		
		loss = jp.cross_entropy(pred, Y)
		loss.backward()
		
		optim.step()
		optim.reset_grad()
		
		total += 1
		sum_loss += loss.data
	
	print(f'Epoch {epoch + 1}')
	print(f'Train loss: {sum_loss / total}')
	
	total = 0
	correct = 0
	model.eval()
	
	for X, Y in loader(X_test, Y_test, batch_size, shuffle=False):
		pred = model(X)
		y_hat = pred.data.argmax(axis=1)
		
		correct_pred = np.count_nonzero(y_hat == Y.data)
		
		total += Y.data.shape[0]
		correct += correct_pred
	
	print(f'Test accuracy: {correct / total} \n')

Epoch 1
Train loss: 0.2675954338611041
Test accuracy: 0.9524238782051282 

Epoch 2
Train loss: 0.10437040211376414
Test accuracy: 0.9646434294871795 

Epoch 3
Train loss: 0.06946764193336885
Test accuracy: 0.9761618589743589 

Epoch 4
Train loss: 0.04954595908662705
Test accuracy: 0.9737580128205128 

Epoch 5
Train loss: 0.04067457241927412
Test accuracy: 0.9728565705128205 

Epoch 6
Train loss: 0.030917226211652202
Test accuracy: 0.9746594551282052 

Epoch 7
Train loss: 0.02591325672681328
Test accuracy: 0.9746594551282052 

Epoch 8
Train loss: 0.025156121921669394
Test accuracy: 0.9776642628205128 

Epoch 9
Train loss: 0.01742678285906311
Test accuracy: 0.9794671474358975 

Epoch 10
Train loss: 0.02093793372408384
Test accuracy: 0.9784655448717948 

Epoch 11
Train loss: 0.017468122337778832
Test accuracy: 0.9772636217948718 

Epoch 12
Train loss: 0.014536529555782603
Test accuracy: 0.9782652243589743 

Epoch 13
Train loss: 0.01350467938998796
Test accuracy: 0.9769631410256411 

Epoch