# 使用keras架构实现MNIST两位数字比较

#### 主要工作
该实验相对之前两个实验的特殊性在于需要自己想办法构建一个mnist数字比较数据集。因此第一步是将数据集构建出来，这里不必保存到本地。具体来讲，思路是：

1. 构建两类数量平衡的二分类样本（这个规模已经足够大）：训练集从x_train的前6000张图片中不重复地随机生成100000个彼此不同的下标二元组，测试集从x_test的前1000张图片中生成10000个，作为数据索引。生成时交替接受“两个数字不同”和“两个数字相同”的二元组，使两类样本数量基本各占一半。
2. 解包这些二元组形成两个索引列表，分别从x_train（测试集则为x_test）中取出对应的图片x1和x2，并在最后一维上堆叠成双通道输入。
3. 生成label：先取出y1和y2，然后判断两个数字是否相同，得到one-hot形式的二分类标签（相同为[0,1]，不同为[1,0]）。最终得到y_train和y_test。

In [None]:
import keras
# Load the standard MNIST split and sanity-check every array's shape.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
for _array, _shape in (
    (x_train, (60000, 28, 28)),
    (x_test, (10000, 28, 28)),
    (y_train, (60000,)),
    (y_test, (10000,)),
):
    assert _array.shape == _shape
del _array, _shape

显示训练集中的一张样例图片（下标为2）及其标签

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Show one training image (index 2) in grayscale, then print its digit label.
sample_idx = 2
mnist_sample, mnist_label = x_train[sample_idx], y_train[sample_idx]
plt.imshow(mnist_sample, cmap='gray')
plt.axis('off')
plt.show()
print(mnist_label)


生成并划分

In [None]:
import random
import numpy as np

def generate_index(n, range_start, range_end, y):
    """Randomly draw ``n`` distinct, label-balanced index pairs.

    Both indices of a pair are drawn with ``random.randint``, so the bounds
    ``[range_start, range_end]`` are inclusive and every drawn index must be a
    valid position in ``y``. Accepted pairs alternate between "different
    label" (``y[i] != y[j]``) and "same label" (``y[i] == y[j]``), starting
    with "different". For even ``n`` each kind therefore makes up exactly half.

    Args:
        n: number of distinct ordered pairs to return.
        range_start: inclusive lower bound for both indices.
        range_end: inclusive upper bound for both indices.
        y: labels, indexed by the drawn indices.

    Returns:
        Two lists ``(x1, x2)`` of length ``n`` holding the first and second
        index of each pair (pair order follows set iteration order).

    Raises:
        ValueError: if fewer than ``n`` distinct ordered pairs exist in range.
    """
    if n > (range_end - range_start + 1) ** 2:
        raise ValueError("no enough indexs")
    if n <= 0:
        return [], []

    unique_pairs = set()
    want_same = False  # False: the next accepted pair must have different labels
    while len(unique_pairs) < n:
        i = random.randint(range_start, range_end)
        j = random.randint(range_start, range_end)
        if (y[i] == y[j]) != want_same:
            continue
        if (i, j) in unique_pairs:
            # Redraw without toggling; toggling on a rejected duplicate would
            # skip a turn and let the same/different balance drift.
            continue
        unique_pairs.add((i, j))
        want_same = not want_same
    x1, x2 = [list(t) for t in zip(*unique_pairs)]
    return x1, x2

def _make_pairs(x, y, n_pairs, pool_size):
    """Return ``n_pairs`` stacked image pairs and their one-hot same/different labels."""
    x = np.asarray(x)
    y = np.asarray(y)
    # generate_index samples with random.randint, whose upper bound is
    # inclusive, so the last valid index into y[:pool_size] is pool_size - 1.
    # Passing pool_size itself would occasionally draw an out-of-range index.
    idx1, idx2 = generate_index(n_pairs, 0, pool_size - 1, y[:pool_size])
    # Two images side by side on a new last (channel) axis.
    x_pairs = np.stack((x[idx1], x[idx2]), axis=-1)
    same = y[idx1] == y[idx2]
    # One-hot: column 0 = different digits, column 1 = same digit.
    y_pairs = np.column_stack((~same, same)).astype(int)
    return x_pairs, y_pairs


def make_dataset(x_train, y_train, x_test, y_test,
                 n_train_pairs=100000, train_pool=6000,
                 n_test_pairs=10000, test_pool=1000):
    """Build balanced "same digit / different digit" pair datasets.

    Training pairs are drawn from the first ``train_pool`` training images, and
    test pairs from the first ``test_pool`` test images (see
    ``generate_index``). The defaults reproduce the original 100000 / 10000
    pair counts.

    Args:
        x_train, y_train: training images ``(N, H, W)`` and integer labels.
        x_test, y_test: test images and integer labels.
        n_train_pairs, n_test_pairs: number of pairs generated per split.
        train_pool, test_pool: how many leading samples of each split to pair.

    Returns:
        ``((x_bi_train, y_bi_train), (x_bi_test, y_bi_test))``. Each ``x`` has
        shape ``(n_pairs, H, W, 2)``. Each ``y`` is one-hot: ``[0, 1]`` when
        both digits are equal, ``[1, 0]`` otherwise.
    """
    x_bi_train, y_bi_train = _make_pairs(x_train, y_train, n_train_pairs, train_pool)
    x_bi_test, y_bi_test = _make_pairs(x_test, y_test, n_test_pairs, test_pool)

    print('Balanced dataset\nTrain: {} : {},shape: {} \nTest: {} : {},shape: {}'.format(
        y_bi_train[:, 1].sum(), len(y_bi_train), x_bi_train.shape,
        y_bi_test[:, 1].sum(), len(y_bi_test), x_bi_test.shape))

    return (x_bi_train, y_bi_train), (x_bi_test, y_bi_test)

# Build the pair datasets once; each is an (inputs, one-hot labels) tuple.
training_set, testing_set = make_dataset(x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test)

model

In [None]:
import keras
from keras.layers import *
from keras.models import Sequential

def residual_blocks(num_filter, input, activation='relu'):
    """Three Conv2D + BatchNorm stages whose output is concatenated with the input.

    Despite the name, the shortcut is a concatenation (dense-style), not an
    additive residual. The result has ``input_channels + num_filter`` channels
    and the same spatial size as ``input`` (stride 1, ``padding='same'``).

    Args:
        num_filter: number of filters in each of the three 2x2 convolutions.
        input: Keras tensor to transform (name kept for existing keyword callers).
        activation: nonlinearity applied after every BatchNorm. Without one,
            the conv/BN stages compose into a single affine map at inference
            time. Pass ``None`` to get that purely linear block.

    Returns:
        The concatenated Keras tensor.
    """
    x = input
    for _ in range(3):
        x = Conv2D(filters=num_filter, kernel_size=(2, 2), strides=1, padding='same')(x)
        x = BatchNormalization(axis=-1)(x)
        if activation is not None:
            x = Activation(activation)(x)
    return concatenate([input, x], axis=-1)

def feature_extraction(input_shape=(28,28,2)):
    """Build the convolutional feature extractor shared by both heads.

    Pipeline: zero-pad by 2 px on each side -> residual_blocks (6 filters)
    -> 2x2 max-pool -> residual_blocks (16 filters) -> 2x2 max-pool -> flatten.
    """
    inputs = Input(input_shape)
    padded = ZeroPadding2D((2, 2))(inputs)

    block_1 = residual_blocks(num_filter=6, input=padded)
    pooled_1 = MaxPooling2D(strides=2)(block_1)

    block_2 = residual_blocks(num_filter=16, input=pooled_1)
    pooled_2 = MaxPooling2D(strides=2)(block_2)

    features = Flatten()(pooled_2)
    return keras.Model(inputs=inputs, outputs=features, name='feature_extraction')

def classifier_10(input_shape=(1536,)):
    """Dense head mapping flattened features to 10 softmax probabilities.

    The default width 1536 matches the flattened output of
    ``feature_extraction`` for its default 28x28x2 input (8 * 8 * 24).
    """
    features = Input(input_shape)
    hidden = Dense(64, activation='relu')(features)
    probabilities = Dense(10, activation='softmax')(hidden)
    return keras.Model(inputs=features, outputs=probabilities, name='10_head_classifier')

def classifier_2(input_shape=(1536,)):
    """Dense head (256 -> 32 -> 2, softmax) used for the two-class pair task.

    The default width 1536 matches the flattened output of
    ``feature_extraction`` for its default 28x28x2 input.
    """
    inputs = Input(input_shape)
    x = inputs
    for width in (256, 32):
        x = Dense(width, activation='relu')(x)
    outputs = Dense(2, activation='softmax')(x)
    return keras.Model(inputs=inputs, outputs=outputs, name='2_head_classifier')

# Instantiate the shared feature extractor and both classification heads
# (left-to-right evaluation keeps the original construction order).
Feature_Extraction, Cls_10, Cls_2 = feature_extraction(), classifier_10(), classifier_2()

使用迁移学习：首先训练单个数字识别（10分类）模型，然后冻结其卷积特征提取部分，只训练二分类的分类头。

In [None]:
# 10-way digit classifier: shared feature extractor followed by the 10-class head.
model_10 = keras.Sequential([Feature_Extraction, Cls_10])
model_10.summary()

train 10-classifier

In [None]:
from keras.utils import to_categorical
# Pretrain on single-digit classification (10 classes). Each image is
# duplicated into both channels so the input matches the 2-channel pair
# format the feature extractor expects.
x_10_train = np.stack((x_train, x_train), axis=-1)
y_10_train = to_categorical(y_train, num_classes=10)
model_10.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
model_10.fit(
    x_10_train,
    y_10_train,
    epochs=2,
    batch_size=64,
    verbose=1,
)
# Persist the trained feature extractor for reuse in the pair classifier.
model_10.layers[0].save('feature_extraction.h5')

In [None]:
# Persist the trained 10-class head (second layer of model_10) separately.
head_10 = model_10.layers[1]
head_10.save('10_head_classifier.h5')

Train the head of the 2-class classifier (feature extractor frozen)

In [None]:
from keras.models import load_model
# Same/different classifier: pretrained feature extractor + the new 2-class head.
model_2 = keras.Sequential()
# Load from the same relative file name the pretraining cell saved to. A plain
# name resolves on every OS (a backslash is an ordinary filename character on
# POSIX). A relative path also avoids the load error seen when the absolute
# path contained Chinese characters. Other arguments stay at their defaults
# because Keras 3's load_model has no `options` parameter.
Pretrained_Feature_Extraction = load_model('feature_extraction.h5')
model_2.add(Pretrained_Feature_Extraction)
model_2.add(Cls_2)
model_2.layers[0].trainable = False  # freeze the pretrained feature extractor
model_2.layers[1].trainable = True  # train only the 2-class head
model_2.summary()

In [None]:
import random
from keras.optimizers import Adam

class LossHistory(keras.callbacks.Callback):
    """Record per-batch training and sampled-validation loss/accuracy.

    After every training batch, ``batchsize`` random samples are drawn from the
    test data and scored, giving a per-step validation curve alongside the
    training metrics Keras reports in ``logs``.

    Lists (one entry per training batch):
        losses, accs: training loss / accuracy taken from ``logs``.
        val_losses, val_accs: loss / accuracy on the sampled test subset.
    """

    def __init__(self, batchsize, test_data=None):
        """
        Args:
            batchsize: number of test samples scored after each training batch.
            test_data: optional ``(x, y_one_hot)`` tuple; defaults to the
                module-level ``testing_set``.
        """
        super().__init__()
        self.losses = []
        self.accs = []
        self.val_losses = []
        self.val_accs = []
        if test_data is None:
            test_data = testing_set
        self.x_test = test_data[0]
        self.y_test = test_data[1]
        self.batch_size = batchsize

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        # Sample from the real test-set size instead of a hard-coded 10000.
        n_test = len(self.x_test)
        indices = random.sample(range(n_test), min(self.batch_size, n_test))
        x_batch = self.x_test[indices]
        y_batch = self.y_test[indices]
        # predict_on_batch does not touch the compiled loss/metric trackers.
        # model.evaluate() resets them, which mid-epoch corrupts the running
        # training loss/accuracy that fit() reports in `logs`.
        probs = np.asarray(self.model.predict_on_batch(x_batch))
        eps = 1e-7  # Keras' default epsilon for clipping in cross-entropy
        clipped = np.clip(probs, eps, 1.0 - eps)
        val_loss = float(np.mean(-np.sum(y_batch * np.log(clipped), axis=-1)))
        val_acc = float(np.mean(np.argmax(probs, axis=-1) == np.argmax(y_batch, axis=-1)))
        self.losses.append(logs.get('loss'))
        self.accs.append(logs.get('accuracy'))
        self.val_losses.append(val_loss)
        self.val_accs.append(val_acc)


model_2.compile(optimizer='Adam', loss="categorical_crossentropy", metrics=["accuracy"])

# Train the model (only the unfrozen 2-class head learns); the callback
# records per-batch training and sampled-validation curves.
bs = 64
callback = LossHistory(batchsize=bs)
train_x, train_y = training_set
history_small = model_2.fit(
    train_x,
    train_y,
    epochs=3,
    batch_size=bs,
    verbose=1,
    callbacks=[callback],
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
# Per-batch curves recorded by LossHistory: training metrics on the top row,
# sampled-validation metrics on the bottom row. Loss curves start at step
# `skip` (presumably so the large initial losses don't squash the y-scale);
# accuracy curves start at step 0.
skip = 20
updating_steps_loss = range(skip, len(callback.losses))
updating_steps = range(len(callback.losses))

plt.figure(figsize=(12, 6))

# training loss
plt.subplot(221)
plt.plot(updating_steps_loss, callback.losses[skip:], label='training_loss', color='blue')
plt.ylabel('loss')
plt.legend()

# training accuracy
plt.subplot(222)
plt.plot(updating_steps, callback.accs, label='training_accuracy', color='blue')
plt.ylabel('accuracy')
plt.legend()

# validation loss
plt.subplot(223)
plt.plot(updating_steps_loss, callback.val_losses[skip:], label='validation_loss', color='orange')
plt.xlabel('updating steps')
plt.ylabel('validation loss')
plt.legend()

# validation accuracy
plt.subplot(224)
plt.plot(updating_steps, callback.val_accs, label='validation_accuracy', color='orange')
plt.xlabel('updating steps')
plt.ylabel('validation accuracy')
plt.legend()

# Save before show(): with the inline (notebook) or a GUI backend the figure
# is closed once show() returns, so a later plt.savefig() writes a blank image.
plt.savefig('output.png')
plt.show()


In [None]:
# plt.show() in the plotting cell normally closes its figure (inline or GUI
# backends), leaving nothing to save here. Saving anyway would write a blank
# output.png, so only save when an open figure actually contains axes.
if plt.get_fignums() and plt.gcf().get_axes():
    plt.savefig('output.png')
else:
    print('No open figure to save; skipping output.png')

测试集上的评估结果（分类报告：precision / recall / F1 / accuracy）

In [None]:
from sklearn.metrics import classification_report
# Batched inference. Predictions do not depend on batch size (inference uses
# BatchNorm's moving statistics), whereas batch_size=1 ran one forward pass
# per test sample.
predictions = model_2.predict(testing_set[0], batch_size=256)

In [None]:
# Compare the "same" column of the one-hot targets with the predicted class
# (argmax of the softmax output; index 1 = same).
repo = classification_report(testing_set[1][:, 1], predictions.argmax(axis=-1), target_names=['not_same', 'same'])
# Name the report after the batch size actually used for training (bs) rather
# than a hard-coded value that can go stale.
with open('bs{}.txt'.format(bs), 'w', encoding='utf-8') as f:
    print('report:\n', repo, file=f)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
 
def show_values(pc, fmt="%.2f", **kw):
    """Write each cell's value at its centre on a ``pcolor`` collection.

    Text is black on light cells (all RGB channels > 0.5) and white otherwise.

    Args:
        pc: the collection returned by ``Axes.pcolor``.
        fmt: %-style format applied to each value.
        **kw: extra keyword arguments forwarded to ``Axes.text``.
    """
    pc.update_scalarmappable()
    ax = pc.axes
    # Newer matplotlib (pcolor -> PolyQuadMesh) keeps the data 2-D; flatten it
    # so it lines up one-to-one with the per-cell paths and face colours.
    values = np.ravel(pc.get_array())
    for path, face, value in zip(pc.get_paths(), pc.get_facecolors(), values):
        # Presumably each closed quad path repeats its first corner and ends
        # with a CLOSEPOLY vertex, so the remaining four corners average to
        # the cell centre.
        x, y = path.vertices[:-2, :].mean(0)
        text_color = (0.0, 0.0, 0.0) if np.all(face[:3] > 0.5) else (1.0, 1.0, 1.0)
        ax.text(x, y, fmt % value, ha="center", va="center", color=text_color, **kw)

 
 
def cm2inch(*tupl):
    '''
    Specify figure size in centimeter in matplotlib
    Source: https://stackoverflow.com/a/22787457/395857
    By gns-ank

    Accepts either separate values, ``cm2inch(w, h)``, or a single tuple,
    ``cm2inch((w, h))``. Returns a tuple of the same values in inches; an empty
    call returns an empty tuple.
    '''
    inch = 2.54
    if tupl and isinstance(tupl[0], tuple):
        tupl = tupl[0]
    return tuple(i / inch for i in tupl)
 
 
def heatmap(AUC, title, xlabel, ylabel, xticklabels, yticklabels, figure_width=40, figure_height=20, correct_orientation=False, cmap='RdBu'):
    '''
    Draw a 2-D array as an annotated pcolor heatmap.

    Inspired by:
    - https://stackoverflow.com/a/16124677/395857
    - https://stackoverflow.com/a/25074150/395857

    Args:
        AUC: 2-D array of cell values (rows -> y axis, columns -> x axis).
        title, xlabel, ylabel: figure title and axis labels.
        xticklabels, yticklabels: one label per column / row.
        figure_width, figure_height: figure size in centimetres.
        correct_orientation: if True, put row 0 at the top and x ticks on top.
        cmap: matplotlib colormap name.
    '''
    fig, ax = plt.subplots()
    c = ax.pcolor(AUC, edgecolors='k', linestyle='dashed', linewidths=0.2, cmap=cmap)

    # Major ticks at the middle of each cell, labelled by row/column.
    ax.set_yticks(np.arange(AUC.shape[0]) + 0.5, minor=False)
    ax.set_xticks(np.arange(AUC.shape[1]) + 0.5, minor=False)
    ax.set_xticklabels(xticklabels, minor=False)
    ax.set_yticklabels(yticklabels, minor=False)

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Remove last blank column
    ax.set_xlim(0, AUC.shape[1])

    # Hide the tick marks but keep their labels. Assigning Tick.tick1On /
    # tick2On (deprecated in matplotlib 3.1 and since removed) has no effect
    # on current versions; tick_params is the supported way.
    ax.tick_params(axis='both', which='both', length=0)

    fig.colorbar(c)

    # Add text in each cell
    show_values(c)

    # Proper orientation (origin at the top left instead of bottom left)
    if correct_orientation:
        ax.invert_yaxis()
        ax.xaxis.tick_top()

    fig.set_size_inches(cm2inch(figure_width, figure_height))
 
 
 
def plot_classification_report(classification_report, title='Classification report ', cmap='RdBu'):
    '''
    Plot scikit-learn classification report.
    Extension based on https://stackoverflow.com/a/31689645/395857

    Parses the per-class rows of the text report produced by
    ``sklearn.metrics.classification_report``. It then draws precision, recall
    and F1-score as a heatmap, labelling each row with its class name and
    support.

    Args:
        classification_report: the report *string* (not the sklearn function).
        title: heatmap title.
        cmap: matplotlib colormap name.

    Raises:
        ValueError: if no per-class rows can be parsed from the report.
    '''
    lines = classification_report.split('\n')

    plotMat = []
    support = []
    class_names = []
    # Skip the header (+ blank line) and, for the usual sklearn layout, the
    # trailing summary block (accuracy / macro avg / weighted avg / final "").
    for line in lines[2:len(lines) - 4]:
        tokens = line.split()
        # A class row is: <name...> precision recall f1-score support.
        if len(tokens) < 5:
            continue
        # Join the leading tokens so multi-word names stay whole.
        class_names.append(' '.join(tokens[:-4]))
        plotMat.append([float(v) for v in tokens[-4:-1]])
        support.append(int(tokens[-1]))

    if not plotMat:
        raise ValueError('no per-class rows found in the classification report')

    xlabel = 'Metrics'
    ylabel = 'Classes'
    xticklabels = ['Precision', 'Recall', 'F1-score']
    yticklabels = ['{0} ({1})'.format(name, sup) for name, sup in zip(class_names, support)]
    figure_width = 25
    figure_height = len(class_names) + 7
    correct_orientation = False
    heatmap(np.array(plotMat), title, xlabel, ylabel, xticklabels, yticklabels,
            figure_width, figure_height, correct_orientation, cmap=cmap)
 
 
def main():
    """Render the report string in the global ``repo`` as a heatmap and save it as a PNG."""
    plot_classification_report(repo)
    plt.savefig('test_plot_classif_report.png', dpi=200, format='png', bbox_inches='tight')
    plt.close()


if __name__ == "__main__":
    main()