In [None]:
import scipy.io as scio
import numpy as np
import tensorflow as tf
import keras
from keras.layers import Input, Dense, ZeroPadding2D, Dropout, Activation
from keras.layers import Input, BatchNormalization, Flatten, Conv2D, AveragePooling2D, MaxPooling2D
from keras.models import Model
import matplotlib.pyplot as plt
import pandas as pd
from tensorflow.keras import layers
import time
# from sklearn.metrics import classification_report
import warnings
from numba import cuda
from keras.utils.vis_utils import plot_model

import os

class LossHistory(keras.callbacks.Callback):
    """Keras callback that records loss/accuracy curves and can plot them.

    Each of ``losses``, ``accuracy``, ``val_loss`` and ``val_acc`` is a dict with
    a ``'batch'`` list (appended after every training batch) and an ``'epoch'``
    list (appended after every epoch). A metric missing from the logs is stored
    as ``None``.
    """

    @staticmethod
    def _lookup_metric(logs, *keys):
        """Return the value of the first key in *keys* present in *logs*, else None.

        Depending on the Keras version, ``metrics=['accuracy']`` is logged as
        'acc'/'val_acc' or 'accuracy'/'val_accuracy', so both spellings are tried.
        """
        for key in keys:
            if key in logs:
                return logs[key]
        return None

    def on_train_begin(self, logs=None):
        # Reset all histories at the start of every fit() call.
        self.losses = {'batch': [], 'epoch': []}
        self.accuracy = {'batch': [], 'epoch': []}
        self.val_loss = {'batch': [], 'epoch': []}
        self.val_acc = {'batch': [], 'epoch': []}

    def _record_metrics(self, period, logs):
        """Append the metrics found in *logs* to the *period* ('batch' or 'epoch') lists."""
        logs = logs or {}
        self.losses[period].append(logs.get('loss'))
        self.accuracy[period].append(self._lookup_metric(logs, 'acc', 'accuracy'))
        # Validation metrics are normally only present in epoch-end logs, so the
        # batch-level entries are typically None.
        self.val_loss[period].append(logs.get('val_loss'))
        self.val_acc[period].append(self._lookup_metric(logs, 'val_acc', 'val_accuracy'))

    def on_batch_end(self, batch, logs=None):
        self._record_metrics('batch', logs)

    def on_epoch_end(self, batch, logs=None):
        # Keras passes the epoch index as the first argument; the original
        # parameter name is kept for compatibility.
        self._record_metrics('epoch', logs)

    def loss_plot(self, loss_type, imgpath):
        """Plot the recorded curves for *loss_type* and save the image to *imgpath*.

        *loss_type* is 'batch' or 'epoch'; validation curves are drawn only for
        'epoch'. The figure is closed after saving so repeated calls do not keep
        accumulating open matplotlib figures.
        """
        iters = range(len(self.losses[loss_type]))
        fig = plt.figure()
        plt.plot(iters, self.accuracy[loss_type], 'r', label='train acc')
        plt.plot(iters, self.losses[loss_type], 'g', label='train loss')
        if loss_type == 'epoch':
            plt.plot(iters, self.val_acc[loss_type], 'b', label='val acc')
            plt.plot(iters, self.val_loss[loss_type], 'k', label='val loss')
        plt.grid(True)
        plt.xlabel(loss_type)
        plt.ylabel('acc-loss')
        plt.legend(loc="upper right")
        plt.savefig(imgpath)
        plt.close(fig)

%run Networks.ipynb
%run Dataset.ipynb
# Recognised values for doOrigin's ``dataType`` and ``NetworkName`` arguments.
_DATA_TYPES = ("Origin", "Spiral", "New")
_NETWORK_NAMES = ("ResNet18", "VGG16", "ANN", "ResNet50", "CNN", "EMGNet")


def _load_dataset(dataType, inputShape, DsetName):
    """Dispatch to the loader for *dataType* and return its 4-tuple.

    The tuple is (data_train, data_test, label_train, label_test). The loader
    functions are defined outside this block (presumably by ``%run Dataset.ipynb``).
    """
    if dataType == "Origin":
        return GetOrigin(5, inputShape)
    if dataType == "Spiral":
        return GetSpiral(inputShape)
    if dataType == "New":
        return GetNew(inputShape, DsetName)
    raise ValueError("Unknown dataType {!r}; expected one of {}".format(dataType, _DATA_TYPES))


def _build_network(NetworkName, inputShape, classes=10):
    """Construct the model named *NetworkName*.

    The builders are defined outside this block (presumably by
    ``%run Networks.ipynb``).
    """
    if NetworkName == "ResNet18":
        return ResNet18(input_shape=inputShape, classes=classes)
    if NetworkName == "VGG16":
        return VGG16(input_shape=inputShape, classes=classes)
    if NetworkName == "ANN":
        return ANN(input_shape=inputShape, classes=classes)
    if NetworkName == "ResNet50":
        return ResNet50(input_shape=inputShape, classes=classes)
    if NetworkName == "CNN":
        return CNN(input_shape=inputShape, classes=classes)
    if NetworkName == "EMGNet":
        return EMGNet(input_shape=inputShape, classes=classes)
    raise ValueError("Unknown NetworkName {!r}; expected one of {}".format(NetworkName, _NETWORK_NAMES))


def doOrigin(num, imgpath, learnR, dataType, NetworkName, inputShape, Epochs, DsetName,
             modelPath='resnet18.h5', numClasses=10):
    """Train one network on one dataset, evaluate it, save it and plot its curves.

    Args:
        num: run index, used only in the printed banner and the result text.
        imgpath: path the per-epoch acc/loss plot is saved to.
        learnR: Adam learning rate.
        dataType: "Origin", "Spiral" or "New"; selects the dataset loader.
        NetworkName: "ResNet18", "VGG16", "ANN", "ResNet50", "CNN" or "EMGNet".
        inputShape: passed to both the dataset loader and the network builder.
        Epochs: number of training epochs.
        DsetName: dataset name, only used when dataType == "New".
        modelPath: file the trained model is saved to. The default keeps the
            historical 'resnet18.h5' regardless of NetworkName; pass a distinct
            path per run to avoid overwriting earlier models.
        numClasses: number of output classes passed to the network builder.

    Returns:
        (text, acc): a summary line of the form
        "Number of train: <num>,Loss = <l>,Accuracy = <a>,time=<t>\\n" and the
        test accuracy rounded to 4 decimals.

    Raises:
        ValueError: if dataType or NetworkName is not recognised. This is checked
            before any data is loaded or any global state is touched.
    """
    # Fail fast instead of hitting an unbound variable after an expensive load.
    if dataType not in _DATA_TYPES:
        raise ValueError("Unknown dataType {!r}; expected one of {}".format(dataType, _DATA_TYPES))
    if NetworkName not in _NETWORK_NAMES:
        raise ValueError("Unknown NetworkName {!r}; expected one of {}".format(NetworkName, _NETWORK_NAMES))

    # NOTE(review): this only takes effect if TensorFlow has not yet initialised
    # its GPU devices in this process; on later calls it is likely a no-op.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    warnings.filterwarnings('ignore')
    print("Training " + str(num) + " Start\n")

    data_train, data_test, label_train, label_test = _load_dataset(dataType, inputShape, DsetName)
    model = _build_network(NetworkName, inputShape, numClasses)

    # The timed span covers compile, fit, save and test-set evaluation.
    t_start = time.perf_counter()

    model.compile(optimizer=tf.keras.optimizers.Adam(lr=learnR, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.00001, amsgrad=False),
                  loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    history = LossHistory()  # create a history instance to collect metrics for the plot

    model.fit(data_train, label_train, epochs=Epochs, batch_size=128, verbose=0,
              validation_data=(data_test, label_test), callbacks=[history])

    model.save(modelPath)

    # evaluate() returns [loss, accuracy] given the single compiled metric.
    preds_test = model.evaluate(data_test, label_test, verbose=0)
    elapsed = time.perf_counter() - t_start

    acc = round(preds_test[1], 4)
    text = ("Number of train: " + str(num)
            + "," + "Loss = " + str(round(preds_test[0], 4))
            + "," + "Accuracy = " + str(acc)
            + "," + "time=" + str(round(elapsed, 4)) + "\n")

    history.loss_plot('epoch', imgpath)
    return text, acc