In [43]:
import tensorflow as tf

In [44]:
from typing import Any


class NaiveDense:
    """Minimal fully connected layer: ``activation(input @ W + B)``.

    Attributes:
        activation: Callable applied to the affine output.
        W: ``tf.Variable`` of shape ``(input_size, output_size)``.
        B: ``tf.Variable`` of shape ``(output_size,)``.
    """

    def __init__(self, input_size, output_size, activation) -> None:
        """Create the layer's trainable variables.

        Args:
            input_size: Number of input features (rows of ``W``).
            output_size: Number of output units (columns of ``W``, length of ``B``).
            activation: Callable applied to ``matmul(input, W) + B``.
        """
        self.activation = activation

        # Kernel starts with uniform random values drawn between 0 and 0.1.
        initial_kernel = tf.random.uniform(
            (input_size, output_size), minval=0, maxval=1e-1
        )
        self.W = tf.Variable(initial_kernel)

        # Bias starts at zero.
        initial_bias = tf.zeros((output_size,))
        self.B = tf.Variable(initial_bias)

    def __call__(self, input) -> Any:
        """Apply the layer to ``input`` and return the activated output.

        NOTE(review): ``tf.matmul`` presumably needs ``input`` to be at least
        rank 2 with a trailing dimension of ``input_size`` -- confirm at callers.
        """
        affine = tf.matmul(input, self.W) + self.B
        return self.activation(affine)

    @property
    def weights(self):
        """Return the trainable variables as a fresh list ``[W, B]``."""
        return [self.W, self.B]

In [45]:
class NaiveSequential:
    """Chains layers so that each layer's output feeds the next layer."""

    def __init__(self, layers: list[NaiveDense]) -> None:
        """Store the ordered list of layers (kept by reference, not copied)."""
        self.layers = layers

    def __call__(self, input) -> Any:
        """Pass ``input`` through every layer in order and return the result."""
        activations = input
        for stage in self.layers:
            activations = stage(activations)
        return activations

    @property
    def weights(self):
        """Return a new flat list of every layer's weights, in layer order."""
        return [variable for stage in self.layers for variable in stage.weights]

In [46]:
# Two-layer classifier: 28*28 inputs -> 512 ReLU units -> 10 softmax outputs.
model = NaiveSequential(
    [
        NaiveDense(input_size=28 * 28, output_size=512, activation=tf.nn.relu),
        NaiveDense(input_size=512, output_size=10, activation=tf.nn.softmax),
    ]
)
# Each layer contributes a kernel and a bias, so two layers give four variables.
assert len(model.weights) == 4

In [47]:
class BatchGenerator:
    """Hands out consecutive, same-position slices of samples and labels.

    Attributes:
        samples: The full sample collection (must support ``len`` and slicing).
        labels: The full label collection, sliced with the same bounds.
        batch_size: Number of items requested per slice.
        batches: Number of slices needed to cover ``samples``.
        index: Start offset of the slice returned by the next ``next()`` call.
    """

    def __init__(self, samples: list, labels: list, batch_size) -> None:
        """Record the data and work out how many batches cover ``samples``."""
        self.samples = samples
        self.labels = labels
        # Ceiling of len(samples) / batch_size, written with floor division.
        self.batches = 1 + (len(samples) - 1) // batch_size
        self.batch_size = batch_size
        self.index = 0

    def next(self):
        """Return the next ``(samples, labels)`` slice and advance the cursor.

        No bounds check is made: once the data is exhausted the slices are
        simply empty.
        """
        window = slice(self.index, self.index + self.batch_size)
        chunk = (self.samples[window], self.labels[window])
        self.index = window.stop
        return chunk
        

In [48]:
from tensorflow.keras import optimizers

# Module-level SGD optimizer used by update_weights below.
optimizer = optimizers.SGD(learning_rate=1e-3)


def update_weights(gradients, weights):
    """Apply one optimizer step, pairing each gradient with its variable.

    Args:
        gradients: Gradients, in the same order as ``weights``.
        weights: Variables to update in place.
    """
    grads_and_vars = zip(gradients, weights)
    optimizer.apply_gradients(grads_and_vars)

def traning_step(model:NaiveSequential,samples,labels):
    """Run one gradient-descent step on a single batch.

    Args:
        model: Callable model exposing its trainable variables via ``weights``.
        samples: Batch of inputs passed to ``model``.
        labels: Integer class labels for ``samples`` (sparse, not one-hot,
            as required by ``sparse_categorical_crossentropy``).

    Returns:
        The mean loss over the batch, computed before the weights are updated.
    """
    # Only the forward pass and the loss need to be recorded by the tape.
    with tf.GradientTape() as tape:
        predictions = model(samples)
        per_sample_losses = tf.keras.losses.sparse_categorical_crossentropy(
            labels, predictions
        )
        avg_losses = tf.reduce_mean(per_sample_losses)

    # Differentiate and update after the tape context has closed. Fetch the
    # weights once so the same list is used for both the gradients and the
    # update (``model.weights`` builds a new list on every access).
    weights = model.weights
    gradients = tape.gradient(avg_losses, weights)
    update_weights(gradients, weights)
    return avg_losses

# Manual weight update (alternative to the optimizer-based update_weights):
# learning_rate=1e-3
# def update_weights(gradients,weights):
#     for g,w in zip(gradients,weights):
#         w.assign_sub(g*learning_rate)



In [49]:
def fit(model, samples, labels, epochs, batch_size):
    """Train ``model`` for ``epochs`` passes over the data, batch by batch.

    Args:
        model: Model forwarded to ``traning_step``.
        samples: Full training inputs.
        labels: Full training labels, aligned with ``samples``.
        epochs: Number of passes over the data.
        batch_size: Number of items per batch.

    Prints the epoch counter (starting at 0) and the loss of every 100th batch.
    """
    for epoch_cnt in range(epochs):
        print('Eposh:', epoch_cnt)
        # A fresh generator each epoch restarts slicing from offset 0.
        generator = BatchGenerator(samples, labels, batch_size)
        for i in range(generator.batches):
            batch_samples, batch_labels = generator.next()
            loss = traning_step(model, batch_samples, batch_labels)
            if i % 100:
                continue
            print(f'Loss at batch[{i}] -> {loss:.2f}')


In [50]:
from tensorflow.keras.datasets import mnist

# Load the MNIST train/test splits.
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Flatten every image to a 28*28 vector, taking the sample count from the
# array itself rather than hard-coding it. Dividing by 255 presumably rescales
# 8-bit pixel values into [0, 1] -- TODO confirm the source dtype.
train_images = train_images.reshape((len(train_images), 28 * 28))
train_images = train_images.astype('float32') / 255
test_images = test_images.reshape((len(test_images), 28 * 28))
test_images = test_images.astype('float32') / 255

# Keyword arguments make the training hyperparameters explicit.
fit(model, train_images, train_labels, epochs=10, batch_size=128)

Eposh: 0
Loss at batch[0] -> 5.07
Loss at batch[100] -> 2.27
Loss at batch[200] -> 2.20
Loss at batch[300] -> 2.10
Loss at batch[400] -> 2.23
Eposh: 1
Loss at batch[0] -> 1.92
Loss at batch[100] -> 1.91
Loss at batch[200] -> 1.82
Loss at batch[300] -> 1.72
Loss at batch[400] -> 1.84
Eposh: 2
Loss at batch[0] -> 1.59
Loss at batch[100] -> 1.60
Loss at batch[200] -> 1.50
Loss at batch[300] -> 1.43
Loss at batch[400] -> 1.52
Eposh: 3
Loss at batch[0] -> 1.33
Loss at batch[100] -> 1.36
Loss at batch[200] -> 1.23
Loss at batch[300] -> 1.22
Loss at batch[400] -> 1.28
Eposh: 4
Loss at batch[0] -> 1.14
Loss at batch[100] -> 1.18
Loss at batch[200] -> 1.04
Loss at batch[300] -> 1.05
Loss at batch[400] -> 1.11
Eposh: 5
Loss at batch[0] -> 0.99
Loss at batch[100] -> 1.03
Loss at batch[200] -> 0.90
Loss at batch[300] -> 0.93
Loss at batch[400] -> 0.99
Eposh: 6
Loss at batch[0] -> 0.88
Loss at batch[100] -> 0.92
Loss at batch[200] -> 0.79
Loss at batch[300] -> 0.84
Loss at batch[400] -> 0.90
Eposh:

In [54]:
import numpy as np
# Forward pass over the whole test set, converted straight to a NumPy array
# so `predictions` holds one type throughout the cell.
predictions = model(test_images).numpy()

# The predicted class is the index of the largest score in each row.
predicted_labels = np.argmax(predictions, axis=1)
matches = predicted_labels == test_labels

# Fraction of test samples whose predicted label equals the true label.
print(f'Accuracy: {matches.mean():.2f}')

KeyboardInterrupt: 