In [85]:
import torch
import numpy as np
import cupy as cpy
from tqdm import tqdm
import eznf
from eznf import Tensor
from eznf import datasets
from eznf import optim

In [86]:
class test(eznf.nn.Module):
    def __init__(self):
        super().__init__()
        self.networks = [
            # eznf.nn.Cov2d(1, 3, 3),
            # eznf.nn.MaxPooling(2),
            # eznf.nn.Flatten(),
            eznf.nn.Linear(784, 256),
            eznf.nn.ReLU(),
            eznf.nn.Linear(256, 10)
        ]
    
    def forward(self, x):
        for i in self.networks:
            x = i(x)
        return x

In [87]:
dataset = datasets.MNIST('./', False)
data = dataset.get()

In [88]:
X_train, Y_train, X_test, Y_test = data
X_train = X_train / 255
X_test = X_test / 255
X_train = np.vstack([i.flatten() for i in X_train.item])
X_test = np.vstack([i.flatten() for i in X_test.item])

Y_train = eznf.one_hot(Tensor(Y_train), 10)
X_train = eznf.Tensor(X_train, requires_grad=False)
Y_test = eznf.one_hot(Tensor(Y_test), 10)
# X_train = X_train[:,None,:,:]

In [89]:
def zero_grad(m):
    for w in m.parameters():
        w.grad = None

def SGD(m: eznf.nn.Module, alpha):
    # 梯度下降
    for w in m.parameters():
        w.item = w.item - alpha*w.grad.item

def accuracy(m, x: eznf.Tensor, y: eznf.Tensor):
    pre = m(x).argmax(axis=0)
    return ((y.item.argmax(axis=0) == pre.item).sum() / x.shape[1]).round(2)

In [105]:
epoches = 5
batch_size = 1024
steps = len(X_train) // batch_size

m = test()
loss = eznf.nn.CrossEntropyLoss()
opt = optim.SGD(alpha=0.01, model=m)

ls = []
acc = []

with tqdm(total=epoches) as t:
    for i in range(epoches):
        for j in range(steps):
            x = X_train[j*batch_size : (j+1)*batch_size]
            y = Y_train[j*batch_size : (j+1)*batch_size]
            out = m(x.T)
            l = loss(out, y.T) / batch_size
            l.backward()
            opt.step()
            opt.zero_grad()
            
            ls.append(l.item)
            acc.append(accuracy(m, x.T, y.T))
        
        t.set_description('Epoch {}'.format(i), refresh=False)
        t.set_postfix(loss=l.item[0], acc=acc[-1], refresh=False)
        t.update(1)

Epoch 4: 100%|██████████| 5/5 [00:06<00:00,  1.21s/it, acc=0.16, loss=2.32]


In [103]:
X_test.T.shape

(784, 10000)

In [92]:
from time import time

a = eznf.ones(1000, 1000)
t1 = time()
for i in range(100):
    b = a @ a
print('CPU: ', time() - t1)

a.to('gpu')
t1 = time()
for i in range(100):
    b = a @ a
print('GPU: ', time() - t1)

CPU:  2.1707780361175537
GPU:  0.005109548568725586
