In [64]:
import numpy as np
from sklearn.datasets import load_digits
from sklearn.preprocessing import StandardScaler

In [65]:
from sklearn.datasets import fetch_openml
X, y = fetch_openml('mnist_784', version=1, return_X_y=True)
# X, y = load_digits(return_X_y=True)


In [66]:
y = y.astype(np.int32)

In [67]:
y = np.eye(10)[y]

In [68]:
X = StandardScaler().fit_transform(X)


In [69]:
X.shape, y.shape

((70000, 784), (70000, 10))

In [70]:
class Linear:
    def __init__(self, in_dim, out_dim):
        self.w = np.random.randn(in_dim, out_dim)
        self.b = np.zeros([out_dim])
        self.dw = None
        self.db = None
        self.x = None
        
    def forward(self, x):
        self.x = x
        return np.matmul(x, self.w) + self.b
    
    def backward(self, d):
        self.dw = np.matmul(self.x.T, d)
        self.db = np.sum(d, axis=0)
        assert self.dw.shape == self.w.shape
        assert self.db.shape == self.b.shape

        d = d @ self.w.T
        return d
        
    def step(self, lr):
        self.w = self.w - lr * self.dw
        self.b = self.b - lr * self.db

In [71]:
class ReLU:
    def __init__(self):
        self.a = None
        
    def forward(self, x):
        self.a = np.maximum(x, 0)
        return self.a
    
    def backward(self, d):
        return d * (self.a != 0).astype(np.float32)
        

In [72]:
class Softmax:
    def __init__(self):
        self.a = None
        
    def forward(self, x):
        assert len(x.shape) == 2
        x = x - np.max(x, axis=-1, keepdims=True)
        self.a = np.exp(x) / np.sum(np.exp(x), keepdims=True, axis=-1)
        return self.a
    def backward(self, d):
        sigma = np.repeat(self.a[:, np.newaxis, ...], self.a.shape[1], 1)
        j = np.swapaxes(sigma, 1, 2) * (np.eye(sigma.shape[1])[np.newaxis, ...] - sigma)
        return (j @ d[..., np.newaxis]).squeeze(-1)


In [73]:
class CrossEntropy:
    def __init__(self, average=True): 
        self.average = average
    def forward(self, y_, y):
        
        l = - np.sum(y * np.log(y_ + 1e-9))
        if self.average:
            l /= len(y)
        return y_, l
    
    def backward(self, y_, y):
        assert y_.shape == y.shape
        d =  - y / (y_ + 1e-9)
        if self.average:
            d /= len(y)
        return d

In [89]:
class LogisticRegression:
    def __init__(self):
        self.linear1 = Linear(784, 128)
        self.relu1 = ReLU()
        self.linear2 = Linear(128, 32)
        self.relu2 = ReLU()
        self.linear3 = Linear(32, 10)
        self.softmax = Softmax()
        self.loss = CrossEntropy()
        
    def forward(self, x, y):
        x = self.linear1.forward(x)
        x = self.relu1.forward(x)
        x = self.linear2.forward(x)
        x = self.relu2.forward(x)
        x = self.linear3.forward(x)
        x = self.softmax.forward(x)
        loss = self.loss.forward(x, y)
        return loss
    
    def backward(self, y_, y):
        d = self.loss.backward(y_, y)
        d = self.softmax.backward(d)
        d = self.linear3.backward(d)
        d = self.relu2.backward(d)
        d = self.linear2.backward(d)
        d = self.relu1.backward(d)
        d = self.linear1.backward(d)

    
    def step(self, lr):
        self.linear1.step(lr)
        self.linear2.step(lr)
        self.linear3.step(lr)

        

In [93]:
logreg = LogisticRegression()

In [94]:
batch_size = 1024

In [95]:
for i in range(10000):
    for j in range(0, len(X), batch_size):
        X_batch = X[j: j + batch_size]
        y_batch = y[j: j + batch_size]
        y_, loss = logreg.forward(X_batch, y_batch)
        logreg.backward(y_, y_batch)
        logreg.step(lr=0.01)
        if j % 32768 == 0:
            print(j)
            print(loss, 'loss')
            print((np.argmax(y_, axis=-1) == np.argmax(y_batch, axis=-1)).sum() / len(y_batch))

0
18.496505292805416 loss
0.107421875
32768
16.946182116003918 loss
0.177734375
65536
16.086337211345644 loss
0.21875
0
17.10940405439259 loss
0.169921875
32768
16.15429415927803 loss
0.216796875
65536
15.39941509923963 loss
0.2529296875
0
16.769827927237053 loss
0.185546875
32768
15.674775865435535 loss
0.240234375
65536
15.167747661812076 loss
0.265625
0
16.058790743877857 loss
0.2197265625
32768
15.101982019679099 loss
0.267578125
65536
14.941254522515113 loss
0.27734375
0
15.553179781141889 loss
0.2412109375
32768
14.760180221880498 loss
0.2861328125
65536
14.398715705364193 loss
0.3037109375
0
15.133704714364086 loss
0.2685546875
32768
14.359092710676993 loss
0.3056640625
65536
14.21783799967715 loss
0.3134765625
0
14.759393980766395 loss
0.283203125
32768
13.871345565210346 loss
0.328125
65536
13.840032218927705 loss
0.326171875
0
14.104736425201262 loss
0.3173828125
32768
13.520520716456282 loss
0.34375
65536
13.458752735508272 loss
0.345703125
0
13.759195960392509 loss
0.332031

KeyboardInterrupt: 