In [1]:
import os

# Hide TensorFlow's native (C++) INFO and WARNING log lines.
# The variable the native logger reads is TF_CPP_MIN_LOG_LEVEL, and it is set
# before tensorflow is imported so that it is already in place when the
# runtime initialises its logging.
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets,layers,optimizers,Sequential,metrics

In [2]:
def preprocess(x,y):
    """
    Prepare one (image, label) example for the model; operates on a single
    example, not a batch.

    :param x: a single image (32*32*3 in this notebook) with pixel values in [0, 255]
    :param y: the label tensor paired with the image
    :return: tuple of (image as float32 rescaled to [-1, 1], label as int32)
    """
    pixels=tf.cast(x,dtype=tf.float32)
    # Rescale [0, 255] -> [-1, 1]
    scaled=2*pixels/255.-1
    label=tf.cast(y,dtype=tf.int32)

    return scaled,label

# Mini-batch size used when the datasets are batched.
batchsz=128

# CIFAR-10 as (train images, train labels), (validation images, validation labels).
(x,y),(x_val,y_val)=datasets.cifar10.load_data()

# Labels arrive as [N, 1]: drop the singleton axis to get [N],
# then one-hot encode over the 10 classes to get [N, 10].
y=tf.one_hot(tf.squeeze(y),depth=10)
y_val=tf.one_hot(tf.squeeze(y_val),depth=10)

# Report shapes and the raw pixel range of both splits.
print(
    'datasets train:',
    x.shape,y.shape,
    x.min(),x.max(),
)
print(
    'datasets validate:',
    x_val.shape,y_val.shape,
    x_val.min(),x_val.max(),
)


datasets train: (50000, 32, 32, 3) (50000, 10) 0 255
datasets validate: (10000, 32, 32, 3) (10000, 10) 0 255


In [3]:
# Build the input pipelines.
# Training: preprocess each example, shuffle with a 10000-element buffer, then batch.
train_db=(
    tf.data.Dataset.from_tensor_slices((x,y))
    .map(preprocess)
    .shuffle(10000)
    .batch(batchsz)
)

# Validation: same preprocessing and batching, without shuffling.
val_db=(
    tf.data.Dataset.from_tensor_slices((x_val,y_val))
    .map(preprocess)
    .batch(batchsz)
)

# Pull a single training batch to check the tensor shapes.
sample=next(iter(train_db))
print('sampel batch:',sample[0].shape,sample[1].shape)

sampel batch: (128, 32, 32, 3) (128, 10)


In [8]:
class MyDense(layers.Layer):
    """
    Bias-free fully connected layer, written as a replacement for the
    standard layers.Dense().

    The layer owns one weight matrix `kernal` of shape [inp_dim, outp_dim]
    and computes `inputs @ kernal`.
    """
    def __init__(self,inp_dim,outp_dim):
        """
        :param inp_dim: size of the last axis of the incoming tensor
        :param outp_dim: size of the last axis of the produced tensor
        """
        super(MyDense,self).__init__()

        # Pass name/shape by keyword: the positional order of add_weight is
        # not the same across Keras releases, so ('w', [..]) positionally can
        # end up binding 'w' to the shape parameter.
        self.kernal=self.add_weight(name='w',shape=[inp_dim,outp_dim])
        # The bias is deliberately left out:
        #self.bias=self.add_weight(name='b',shape=[outp_dim])

    def call(self,inputs,training=None):
        """
        :param inputs: tensor whose last axis has size inp_dim
        :param training: accepted for Keras API compatibility; not used
        :return: inputs @ kernal
        """
        x=inputs@self.kernal
        return x
    
class MyNetwork(keras.Model):
    """
    Five-layer fully connected classifier built from MyDense layers.

    Layer widths: 32*32*3 -> 256 -> 128 -> 64 -> 32 -> 10, with a ReLU after
    every layer except the last.
    """
    def __init__(self):
        super(MyNetwork,self).__init__()
        # Hidden layers
        self.fc1=MyDense(32*32*3,256)
        self.fc2=MyDense(256,128)
        self.fc3=MyDense(128,64)
        self.fc4=MyDense(64,32)
        # Output layer: one value per class
        self.fc5=MyDense(32,10)

    def call(self,inputs,training=None):
        """
        :param inputs: b,32,32,3
        :param training: not used by this model
        :return: b,10 (output of the last layer, no activation applied)
        """
        # Flatten every image: [b,32,32,3] -> [b,32*32*3]
        out=tf.reshape(inputs,[-1,32*32*3])

        # [b,32*32*3] -> [b,256] -> [b,128] -> [b,64] -> [b,32], ReLU after each layer
        for hidden in (self.fc1,self.fc2,self.fc3,self.fc4):
            out=tf.nn.relu(hidden(out))

        # [b,32] -> [b,10]
        return self.fc5(out)

In [9]:
# Build the model and train it for 10 epochs, validating after every epoch.
network=MyNetwork()
network.compile(
    # `learning_rate` is the supported argument name; `lr` is a deprecated alias.
    optimizer=optimizers.Adam(learning_rate=1e-3),
    # The model's last layer has no activation and the labels are one-hot,
    # hence categorical cross-entropy with from_logits=True.
    loss=tf.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
network.fit(train_db,epochs=10,validation_data=val_db,validation_freq=1)

Train for 391 steps, validate for 79 steps
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x23bc5051a88>

In [10]:
# Evaluate the trained model, then write its weights to disk.
network.evaluate(val_db)
network.save_weights('ckpt/saved_weights')
# Drop the trained model so the evaluation below can only rely on the saved weights.
del network
print('saved the weights')

# Re-create the same architecture, compile it the same way, restore the
# weights, and evaluate again.
network=MyNetwork()
network.compile(
    # `learning_rate` is the supported argument name; `lr` is a deprecated alias.
    optimizer=optimizers.Adam(learning_rate=1e-3),
    loss=tf.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
network.load_weights('ckpt/saved_weights')
network.evaluate(val_db)


saved the weights


[1.4688046265252028, 0.52]