In [1]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import numpy as np
import pickle
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models
from tensorflow.keras.utils import to_categorical
from tensorflow import keras
from keras import models
from keras.models import Sequential
from keras.optimizers import Adam
import tensorflow as tf
import os
import copy

In [2]:
def cross_validate_and_select_best_params(model, X, y, n_splits=5):
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    
    scores = {
        'accuracy': [],
        'precision': [],
        'recall': [],
        'f1': [],
        'roc_auc': []
    }

    for train_index, test_index in kfold.split(X):
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        # print(f"y_train shape = {y_train.shape}, y_test shape = {y_test.shape}")
        history = model.fit(X_train, y_train, epochs=5, batch_size=128, validation_data=(X_test, y_test))

        y_pred_prob = model.predict(X_test)
        y_pred = np.argmax(y_pred_prob, axis=1)
        y_test = np.argmax(y_test, axis=1) 
        
        scores['accuracy'].append(accuracy_score(y_test, y_pred))
        scores['precision'].append(precision_score(y_test, y_pred, average='macro'))
        scores['recall'].append(recall_score(y_test, y_pred, average='macro'))
        scores['f1'].append(f1_score(y_test, y_pred, average='macro'))
        scores['roc_auc'].append(roc_auc_score(y_test, y_pred_prob, multi_class='ovr', average='macro'))  # ROC AUC 多分类问题

    avg_scores = {key: np.mean(value) for key, value in scores.items()}
    return avg_scores

In [2]:
import os
import pickle
import numpy as np

def load_cifar10_data(data_dir):
    X = []
    Y = []

    # Load training data batches
    for i in range(1, 6):
        with open(os.path.join(data_dir, f'data_batch_{i}'), 'rb') as file:
            batch = pickle.load(file, encoding='latin1')
            X.append(batch['data'])
            Y.extend(batch['labels'])

    # Combine training data
    X = np.concatenate(X)

    # Load test data
    with open(os.path.join(data_dir, 'test_batch'), 'rb') as file:
        test_batch = pickle.load(file, encoding='latin1')
        X_test = test_batch['data']
        Y_test = test_batch['labels']

    # Combine test data
    X = np.vstack((X, X_test))
    Y.extend(Y_test)

    # Convert labels to numpy array
    Y = np.array(Y)

    # Note: Removing the reshaping operation
    # X = X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)  # (N, H, W, C)

    # Load label names from batches.meta
    with open(os.path.join(data_dir, 'batches.meta'), 'rb') as file:
        meta = pickle.load(file, encoding='latin1')
        label_names = meta['label_names']

    return X, Y, label_names


In [3]:
data_dir = 'cifar-10-batches-py'  # path directory
X_combined, y_combined, meta = load_cifar10_data(data_dir)

In [5]:
def build_simple_cnn(input_shape, num_classes, learning_rate=0.0005, num_filter=16, filter_size=(3,3), dropout_rate=0.5):
    model = models.Sequential()
    model.add(layers.Conv2D(num_filter, filter_size, activation='relu', input_shape=input_shape, padding='same'))
    model.add(layers.BatchNormalization())
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Flatten())
    model.add(layers.Dense(128, activation='relu'))
    model.add(layers.Dropout(dropout_rate))
    model.add(layers.Dense(num_classes, activation='softmax'))
    model.compile(optimizer=Adam(learning_rate=learning_rate), loss='categorical_crossentropy', metrics=['accuracy'])
    return model


In [6]:
print(y_combined.shape)

(60000,)


In [64]:
input_shape = (32, 32, 3)
num_classes = 10

num_filters=[8]
learning_rates=[0.001, 0.0005, 0.0015, 0.01, 0.005]
filter_sizes=[(1,1), (3,3), (5,5), (7,7), (9,9)]
dropout_rates=[0, 0.1, 0.3, 0.5, 0.75]

highest_config=[]
highest_accuracy=0
for learning_rate in learning_rates:
    highest_num_filter = 0
    highest_num_filter_accuracy=0
    best_num_filter = False
    for num_filter in num_filters:
        highest_filter_size = (0,0)
        highest_filter_size_accuracy=0
        best_filter_size = False
        for filter_size in filter_sizes:
            highest_dropout_rate = -1
            highest_dropout_rate_accuracy=0
            for dropout_rate in dropout_rates:
                current_config = [learning_rate, num_filter, filter_size, dropout_rate]
                model = build_simple_cnn(input_shape, num_classes, learning_rate, num_filter, filter_size, dropout_rate)

                X_combined_reshaped = X_combined.reshape(-1, 32, 32, 3)
                y_combined_categorical = to_categorical(y_combined, num_classes=10)
                X_combined_reshaped = X_combined_reshaped.astype('float32') / 255.0

                avg_scores = cross_validate_and_select_best_params(model, X_combined_reshaped, y_combined_categorical)
                print("current_config: ", current_config)
                print(avg_scores)

                if avg_scores['accuracy'] > highest_dropout_rate_accuracy:
                    highest_dropout_rate_accuracy = avg_scores['accuracy']
                    highest_dropout_rate = current_config[3]
                else:
                    print("Last dropout_rate is best, break")
                    break

                if avg_scores['accuracy'] >= highest_filter_size_accuracy:
                    highest_filter_size_accuracy = avg_scores['accuracy']
                    highest_filter_size = current_config[2]
                    best_filter_size = False
                elif highest_filter_size_accuracy > avg_scores['accuracy'] and filter_size > highest_filter_size:
                    print("Last filter_size is best, break")
                    best_filter_size = True

                if avg_scores['accuracy'] >= highest_num_filter_accuracy:
                    highest_num_filter_accuracy = avg_scores['accuracy']
                    highest_num_filter = current_config[1]
                    best_num_filter = False
                elif highest_num_filter_accuracy > avg_scores['accuracy'] and num_filter > highest_num_filter:
                    print("Last num_filter is best, break")
                    best_num_filter = True


                if avg_scores['accuracy'] > highest_accuracy:
                    highest_config = current_config
                    highest_accuracy = avg_scores['accuracy']
                    print("highest_config: ", highest_config)
                    print("highest_accuracy: ", highest_accuracy)

            if best_filter_size:
                break
        if best_num_filter:
            break
            
print("highest_config: ", highest_config)
print("highest_accuracy: ", highest_accuracy)



Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 8ms/step - accuracy: 0.3947 - loss: 1.7181 - val_accuracy: 0.4337 - val_loss: 1.6551
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.5795 - loss: 1.2057 - val_accuracy: 0.5570 - val_loss: 1.2822
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.6373 - loss: 1.0387 - val_accuracy: 0.5552 - val_loss: 1.3060
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.6891 - loss: 0.9051 - val_accuracy: 0.5617 - val_loss: 1.3006
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.7365 - loss: 0.7768 - val_accuracy: 0.5665 - val_loss: 1.3299
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.7160 - loss: 0.8515 -

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.3748 - loss: 1.7763 - val_accuracy: 0.3875 - val_loss: 1.7234
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.5659 - loss: 1.2290 - val_accuracy: 0.5590 - val_loss: 1.2756
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.6191 - loss: 1.0828 - val_accuracy: 0.5773 - val_loss: 1.2268
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.6556 - loss: 0.9835 - val_accuracy: 0.5764 - val_loss: 1.2384
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 7ms/step - accuracy: 0.6917 - loss: 0.8901 - val_accuracy: 0.5792 - val_loss: 1.2352
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 975us/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.6760 - loss: 0.9541

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.4043 - loss: 1.7186 - val_accuracy: 0.4821 - val_loss: 1.5053
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.5849 - loss: 1.1801 - val_accuracy: 0.5613 - val_loss: 1.2387
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.6484 - loss: 1.0042 - val_accuracy: 0.5591 - val_loss: 1.2608
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.6925 - loss: 0.8826 - val_accuracy: 0.5932 - val_loss: 1.2184
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.7389 - loss: 0.7617 - val_accuracy: 0.5991 - val_loss: 1.2242
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 974us/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 7ms/step - accuracy: 0.7235 - loss: 0.8097

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.3848 - loss: 1.7725 - val_accuracy: 0.4952 - val_loss: 1.4810
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.5524 - loss: 1.2658 - val_accuracy: 0.5516 - val_loss: 1.2801
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.6066 - loss: 1.1160 - val_accuracy: 0.5507 - val_loss: 1.2786
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.6512 - loss: 0.9976 - val_accuracy: 0.5818 - val_loss: 1.2025
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.6813 - loss: 0.9090 - val_accuracy: 0.5792 - val_loss: 1.2372
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 971us/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.6794 - loss: 0.9336

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.3736 - loss: 1.8289 - val_accuracy: 0.4445 - val_loss: 1.5681
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.5537 - loss: 1.2629 - val_accuracy: 0.5265 - val_loss: 1.3411
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.6103 - loss: 1.1099 - val_accuracy: 0.5669 - val_loss: 1.2464
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.6518 - loss: 0.9902 - val_accuracy: 0.5401 - val_loss: 1.3545
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.6833 - loss: 0.8963 - val_accuracy: 0.5421 - val_loss: 1.3741
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.6810 - loss: 0.9334 -

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 7ms/step - accuracy: 0.3596 - loss: 1.8403 - val_accuracy: 0.4502 - val_loss: 1.5902
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.5173 - loss: 1.3695 - val_accuracy: 0.5253 - val_loss: 1.3493
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.5665 - loss: 1.2259 - val_accuracy: 0.5604 - val_loss: 1.2687
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.6089 - loss: 1.1121 - val_accuracy: 0.5680 - val_loss: 1.2465
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.6406 - loss: 1.0210 - val_accuracy: 0.5802 - val_loss: 1.2291
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.6458 - loss: 1.0289 -

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.4030 - loss: 1.7466 - val_accuracy: 0.4427 - val_loss: 1.6737
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 15ms/step - accuracy: 0.5923 - loss: 1.1663 - val_accuracy: 0.5460 - val_loss: 1.3063
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.6685 - loss: 0.9583 - val_accuracy: 0.5837 - val_loss: 1.2259
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.7228 - loss: 0.7980 - val_accuracy: 0.5838 - val_loss: 1.2843
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.7726 - loss: 0.6634 - val_accuracy: 0.5804 - val_loss: 1.3476
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.7549 - loss: 0.

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 15ms/step - accuracy: 0.3788 - loss: 1.8088 - val_accuracy: 0.4022 - val_loss: 1.7163
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.5276 - loss: 1.3375 - val_accuracy: 0.5389 - val_loss: 1.3205
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 15ms/step - accuracy: 0.5786 - loss: 1.1842 - val_accuracy: 0.3684 - val_loss: 2.3253
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.6186 - loss: 1.0752 - val_accuracy: 0.5457 - val_loss: 1.3576
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.6493 - loss: 0.9934 - val_accuracy: 0.5788 - val_loss: 1.2358
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.6435 - loss: 1.

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.4068 - loss: 1.7502 - val_accuracy: 0.4928 - val_loss: 1.5072
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.5936 - loss: 1.1647 - val_accuracy: 0.5687 - val_loss: 1.2413
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.6506 - loss: 1.0065 - val_accuracy: 0.6006 - val_loss: 1.1700
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 15ms/step - accuracy: 0.6866 - loss: 0.8986 - val_accuracy: 0.5617 - val_loss: 1.3548
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.7178 - loss: 0.8176 - val_accuracy: 0.6139 - val_loss: 1.1707
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.7139 - loss: 0.

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 15ms/step - accuracy: 0.3884 - loss: 1.7765 - val_accuracy: 0.4782 - val_loss: 1.5666
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.5706 - loss: 1.2139 - val_accuracy: 0.5706 - val_loss: 1.2045
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.6346 - loss: 1.0456 - val_accuracy: 0.6096 - val_loss: 1.1377
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.6723 - loss: 0.9331 - val_accuracy: 0.5934 - val_loss: 1.1925
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.7154 - loss: 0.8192 - val_accuracy: 0.5974 - val_loss: 1.2248
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.7025 - loss: 0.

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 16ms/step - accuracy: 0.3376 - loss: 1.9234 - val_accuracy: 0.4619 - val_loss: 1.6822
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 16ms/step - accuracy: 0.5031 - loss: 1.3943 - val_accuracy: 0.5587 - val_loss: 1.2612
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 16ms/step - accuracy: 0.5550 - loss: 1.2549 - val_accuracy: 0.5776 - val_loss: 1.2252
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.5872 - loss: 1.1585 - val_accuracy: 0.5883 - val_loss: 1.1818
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 16ms/step - accuracy: 0.6157 - loss: 1.0877 - val_accuracy: 0.5984 - val_loss: 1.1623
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 16ms/step - accuracy: 0.6235 - loss: 1.

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 16ms/step - accuracy: 0.3586 - loss: 1.8781 - val_accuracy: 0.4511 - val_loss: 1.5910
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 16ms/step - accuracy: 0.5463 - loss: 1.2754 - val_accuracy: 0.5551 - val_loss: 1.2501
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 16ms/step - accuracy: 0.6105 - loss: 1.1103 - val_accuracy: 0.5442 - val_loss: 1.3009
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 16ms/step - accuracy: 0.6541 - loss: 0.9879 - val_accuracy: 0.5654 - val_loss: 1.2847
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 16ms/step - accuracy: 0.6867 - loss: 0.8940 - val_accuracy: 0.5631 - val_loss: 1.3188
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 16ms/step - accuracy: 0.6829 - loss: 0.

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 15ms/step - accuracy: 0.3686 - loss: 1.8245 - val_accuracy: 0.5034 - val_loss: 1.5213
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.5380 - loss: 1.2989 - val_accuracy: 0.4864 - val_loss: 1.5039
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.5882 - loss: 1.1640 - val_accuracy: 0.5707 - val_loss: 1.2238
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.6332 - loss: 1.0440 - val_accuracy: 0.5838 - val_loss: 1.1877
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.6689 - loss: 0.9390 - val_accuracy: 0.5254 - val_loss: 1.5111
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.6626 - loss: 0.

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 15ms/step - accuracy: 0.3129 - loss: 1.9952 - val_accuracy: 0.4752 - val_loss: 1.5984
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.4738 - loss: 1.4668 - val_accuracy: 0.4909 - val_loss: 1.4116
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.5313 - loss: 1.3289 - val_accuracy: 0.3097 - val_loss: 2.6939
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.5634 - loss: 1.2308 - val_accuracy: 0.5479 - val_loss: 1.2907
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.5919 - loss: 1.1524 - val_accuracy: 0.5529 - val_loss: 1.2609
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.5963 - loss: 1.

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 33ms/step - accuracy: 0.4077 - loss: 1.7858 - val_accuracy: 0.4212 - val_loss: 1.6985
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 32ms/step - accuracy: 0.5809 - loss: 1.1890 - val_accuracy: 0.5460 - val_loss: 1.3024
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 32ms/step - accuracy: 0.6440 - loss: 1.0214 - val_accuracy: 0.5747 - val_loss: 1.2488
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 32ms/step - accuracy: 0.6947 - loss: 0.8780 - val_accuracy: 0.5774 - val_loss: 1.2565
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 33ms/step - accuracy: 0.7376 - loss: 0.7563 - val_accuracy: 0.5862 - val_loss: 1.2552
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 33ms/step - accuracy: 0.7260 - lo

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 32ms/step - accuracy: 0.3696 - loss: 1.8904 - val_accuracy: 0.3927 - val_loss: 1.7670
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 32ms/step - accuracy: 0.5604 - loss: 1.2484 - val_accuracy: 0.4983 - val_loss: 1.5036
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 32ms/step - accuracy: 0.6178 - loss: 1.0761 - val_accuracy: 0.5807 - val_loss: 1.2259
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 32ms/step - accuracy: 0.6597 - loss: 0.9630 - val_accuracy: 0.5788 - val_loss: 1.2749
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 32ms/step - accuracy: 0.7002 - loss: 0.8520 - val_accuracy: 0.5957 - val_loss: 1.2232
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 32ms/step - accuracy: 0.6908 - lo

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 32ms/step - accuracy: 0.3123 - loss: 2.0613 - val_accuracy: 0.3524 - val_loss: 1.8276
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 32ms/step - accuracy: 0.4783 - loss: 1.4770 - val_accuracy: 0.4652 - val_loss: 1.5829
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 32ms/step - accuracy: 0.5309 - loss: 1.3232 - val_accuracy: 0.5368 - val_loss: 1.3604
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 32ms/step - accuracy: 0.5635 - loss: 1.2284 - val_accuracy: 0.5678 - val_loss: 1.2597
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 32ms/step - accuracy: 0.5878 - loss: 1.1516 - val_accuracy: 0.5713 - val_loss: 1.2230
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 33ms/step - accuracy: 0.5860 - lo

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 33ms/step - accuracy: 0.4032 - loss: 1.8538 - val_accuracy: 0.4302 - val_loss: 1.6567
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 32ms/step - accuracy: 0.6211 - loss: 1.0893 - val_accuracy: 0.5703 - val_loss: 1.2138
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 32ms/step - accuracy: 0.6724 - loss: 0.9344 - val_accuracy: 0.5709 - val_loss: 1.3175
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 32ms/step - accuracy: 0.7317 - loss: 0.7653 - val_accuracy: 0.6024 - val_loss: 1.1866
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 32ms/step - accuracy: 0.7699 - loss: 0.6610 - val_accuracy: 0.5978 - val_loss: 1.2526
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 32ms/step - accuracy: 0.7563 - lo

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 35ms/step - accuracy: 0.3829 - loss: 1.8687 - val_accuracy: 0.4613 - val_loss: 1.6407
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 35ms/step - accuracy: 0.5749 - loss: 1.2046 - val_accuracy: 0.5634 - val_loss: 1.2526
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 35ms/step - accuracy: 0.6303 - loss: 1.0390 - val_accuracy: 0.6038 - val_loss: 1.1528
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 34ms/step - accuracy: 0.6776 - loss: 0.9185 - val_accuracy: 0.5926 - val_loss: 1.2033
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 34ms/step - accuracy: 0.7220 - loss: 0.7950 - val_accuracy: 0.5863 - val_loss: 1.2752
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 34ms/step - accuracy: 0.7140 - lo

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 36ms/step - accuracy: 0.3700 - loss: 2.1105 - val_accuracy: 0.4604 - val_loss: 1.5757
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 35ms/step - accuracy: 0.5673 - loss: 1.2459 - val_accuracy: 0.5520 - val_loss: 1.2750
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 35ms/step - accuracy: 0.6222 - loss: 1.0862 - val_accuracy: 0.4192 - val_loss: 1.8885
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 35ms/step - accuracy: 0.6688 - loss: 0.9577 - val_accuracy: 0.5397 - val_loss: 1.4286
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 35ms/step - accuracy: 0.6969 - loss: 0.8722 - val_accuracy: 0.5610 - val_loss: 1.3569
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 36ms/step - accuracy: 0.6960 - lo

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 36ms/step - accuracy: 0.3453 - loss: 1.9958 - val_accuracy: 0.4363 - val_loss: 1.6330
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 36ms/step - accuracy: 0.5207 - loss: 1.3392 - val_accuracy: 0.4652 - val_loss: 1.5008
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 35ms/step - accuracy: 0.5749 - loss: 1.1893 - val_accuracy: 0.5698 - val_loss: 1.2503
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 35ms/step - accuracy: 0.6096 - loss: 1.0967 - val_accuracy: 0.5141 - val_loss: 1.4433
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 34ms/step - accuracy: 0.6459 - loss: 1.0013 - val_accuracy: 0.5664 - val_loss: 1.2750
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 35ms/step - accuracy: 0.6434 - lo

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 85ms/step - accuracy: 0.3911 - loss: 1.9874 - val_accuracy: 0.4172 - val_loss: 1.7370
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 83ms/step - accuracy: 0.5818 - loss: 1.1953 - val_accuracy: 0.5573 - val_loss: 1.2658
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 83ms/step - accuracy: 0.6497 - loss: 1.0187 - val_accuracy: 0.5128 - val_loss: 1.5198
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 85ms/step - accuracy: 0.6909 - loss: 0.8874 - val_accuracy: 0.5558 - val_loss: 1.3733
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 83ms/step - accuracy: 0.7338 - loss: 0.7725 - val_accuracy: 0.5862 - val_loss: 1.2688
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 85ms/step - accuracy: 0.7250 - lo

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 83ms/step - accuracy: 0.3595 - loss: 2.2933 - val_accuracy: 0.4008 - val_loss: 1.7970
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 84ms/step - accuracy: 0.5243 - loss: 1.3480 - val_accuracy: 0.5213 - val_loss: 1.3604
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 84ms/step - accuracy: 0.5850 - loss: 1.1716 - val_accuracy: 0.5560 - val_loss: 1.2869
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 87ms/step - accuracy: 0.6251 - loss: 1.0565 - val_accuracy: 0.5552 - val_loss: 1.2951
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 85ms/step - accuracy: 0.6576 - loss: 0.9645 - val_accuracy: 0.5790 - val_loss: 1.2193
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 84ms/step - accuracy: 0.6538 - lo

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 87ms/step - accuracy: 0.4166 - loss: 1.8722 - val_accuracy: 0.4805 - val_loss: 1.6104
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 86ms/step - accuracy: 0.6228 - loss: 1.0862 - val_accuracy: 0.5437 - val_loss: 1.3208
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 86ms/step - accuracy: 0.6901 - loss: 0.9007 - val_accuracy: 0.6023 - val_loss: 1.1583
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 86ms/step - accuracy: 0.7554 - loss: 0.7162 - val_accuracy: 0.5902 - val_loss: 1.2626
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 86ms/step - accuracy: 0.8062 - loss: 0.5638 - val_accuracy: 0.6094 - val_loss: 1.2595
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 86ms/step - accuracy: 0.7797 - lo

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 88ms/step - accuracy: 0.3646 - loss: 2.0456 - val_accuracy: 0.4707 - val_loss: 1.6500
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 86ms/step - accuracy: 0.5582 - loss: 1.2304 - val_accuracy: 0.5208 - val_loss: 1.3807
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 86ms/step - accuracy: 0.6192 - loss: 1.0792 - val_accuracy: 0.5513 - val_loss: 1.3126
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 85ms/step - accuracy: 0.6682 - loss: 0.9320 - val_accuracy: 0.5848 - val_loss: 1.2332
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 87ms/step - accuracy: 0.7112 - loss: 0.8250 - val_accuracy: 0.5954 - val_loss: 1.2520
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 87ms/step - accuracy: 0.7078 - lo

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 90ms/step - accuracy: 0.3618 - loss: 2.1510 - val_accuracy: 0.4006 - val_loss: 1.6829
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 89ms/step - accuracy: 0.5634 - loss: 1.2425 - val_accuracy: 0.5321 - val_loss: 1.3739
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 89ms/step - accuracy: 0.6251 - loss: 1.0643 - val_accuracy: 0.4911 - val_loss: 1.5849
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 87ms/step - accuracy: 0.6700 - loss: 0.9478 - val_accuracy: 0.5732 - val_loss: 1.2813
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 88ms/step - accuracy: 0.7112 - loss: 0.8280 - val_accuracy: 0.5919 - val_loss: 1.2484
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 88ms/step - accuracy: 0.7041 - lo

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m35s[0m 88ms/step - accuracy: 0.3346 - loss: 2.3804 - val_accuracy: 0.4552 - val_loss: 1.6567
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 88ms/step - accuracy: 0.5004 - loss: 1.3867 - val_accuracy: 0.4752 - val_loss: 1.4563
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 89ms/step - accuracy: 0.5569 - loss: 1.2333 - val_accuracy: 0.5111 - val_loss: 1.4139
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 88ms/step - accuracy: 0.5961 - loss: 1.1328 - val_accuracy: 0.5592 - val_loss: 1.2987
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 88ms/step - accuracy: 0.6337 - loss: 1.0304 - val_accuracy: 0.5726 - val_loss: 1.2391
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 88ms/step - accuracy: 0.6412 - lo

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 182ms/step - accuracy: 0.3361 - loss: 2.8128 - val_accuracy: 0.3217 - val_loss: 1.8399
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m68s[0m 182ms/step - accuracy: 0.5577 - loss: 1.2632 - val_accuracy: 0.5440 - val_loss: 1.3036
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m68s[0m 182ms/step - accuracy: 0.6108 - loss: 1.1033 - val_accuracy: 0.5582 - val_loss: 1.2797
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m68s[0m 181ms/step - accuracy: 0.6455 - loss: 1.0032 - val_accuracy: 0.5628 - val_loss: 1.2710
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m67s[0m 179ms/step - accuracy: 0.6858 - loss: 0.8838 - val_accuracy: 0.5723 - val_loss: 1.3069
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 5ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m67s[0m 179ms/step - accuracy: 0.681

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m67s[0m 177ms/step - accuracy: 0.3201 - loss: 2.8198 - val_accuracy: 0.3356 - val_loss: 1.9083
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m66s[0m 177ms/step - accuracy: 0.5108 - loss: 1.3841 - val_accuracy: 0.5293 - val_loss: 1.3599
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m66s[0m 177ms/step - accuracy: 0.5616 - loss: 1.2418 - val_accuracy: 0.5591 - val_loss: 1.2604
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m66s[0m 176ms/step - accuracy: 0.5927 - loss: 1.1478 - val_accuracy: 0.5458 - val_loss: 1.2984
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m66s[0m 177ms/step - accuracy: 0.6143 - loss: 1.0698 - val_accuracy: 0.5704 - val_loss: 1.2363
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 5ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m66s[0m 177ms/step - accuracy: 0.617

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m68s[0m 180ms/step - accuracy: 0.4019 - loss: 2.6407 - val_accuracy: 0.4993 - val_loss: 1.6091
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 183ms/step - accuracy: 0.6103 - loss: 1.1038 - val_accuracy: 0.5432 - val_loss: 1.3210
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m72s[0m 191ms/step - accuracy: 0.6849 - loss: 0.9109 - val_accuracy: 0.5842 - val_loss: 1.2197
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 186ms/step - accuracy: 0.7447 - loss: 0.7357 - val_accuracy: 0.5655 - val_loss: 1.3891
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 187ms/step - accuracy: 0.8054 - loss: 0.5693 - val_accuracy: 0.5612 - val_loss: 1.6522
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 5ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 187ms/step - accuracy: 0.788

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 188ms/step - accuracy: 0.3655 - loss: 2.5755 - val_accuracy: 0.4544 - val_loss: 1.7289
Epoch 2/5


KeyboardInterrupt: 

In [7]:
input_shape = (32, 32, 3)
num_classes = 10

num_filters=[8, 16, 32, 64, 128]
learning_rates=[0.0005]
filter_sizes=[(3,3)]
dropout_rates=[0]

highest_config=[]
highest_accuracy=0
for learning_rate in learning_rates:
    highest_num_filter = 0
    highest_num_filter_accuracy=0
    best_num_filter = False
    for num_filter in num_filters:
        highest_filter_size = (0,0)
        highest_filter_size_accuracy=0
        best_filter_size = False
        for filter_size in filter_sizes:
            highest_dropout_rate = -1
            highest_dropout_rate_accuracy=0
            for dropout_rate in dropout_rates:
                current_config = [learning_rate, num_filter, filter_size, dropout_rate]
                model = build_simple_cnn(input_shape, num_classes, learning_rate, num_filter, filter_size, dropout_rate)

                X_combined_reshaped = X_combined.reshape(-1, 32, 32, 3)
                y_combined_categorical = to_categorical(y_combined, num_classes=10)
                X_combined_reshaped = X_combined_reshaped.astype('float32') / 255.0

                avg_scores = cross_validate_and_select_best_params(model, X_combined_reshaped, y_combined_categorical)
                print("current_config: ", current_config)
                print(avg_scores)

                if avg_scores['accuracy'] > highest_dropout_rate_accuracy:
                    highest_dropout_rate_accuracy = avg_scores['accuracy']
                    highest_dropout_rate = current_config[3]
                else:
                    print("Last dropout_rate is best, break")
                    break

                if avg_scores['accuracy'] >= highest_filter_size_accuracy:
                    highest_filter_size_accuracy = avg_scores['accuracy']
                    highest_filter_size = current_config[2]
                    best_filter_size = False
                elif highest_filter_size_accuracy > avg_scores['accuracy'] and filter_size > highest_filter_size:
                    print("Last filter_size is best, break")
                    best_filter_size = True

                if avg_scores['accuracy'] >= highest_num_filter_accuracy:
                    highest_num_filter_accuracy = avg_scores['accuracy']
                    highest_num_filter = current_config[1]
                    best_num_filter = False
                elif highest_num_filter_accuracy > avg_scores['accuracy'] and num_filter > highest_num_filter:
                    print("Last num_filter is best, break")
                    best_num_filter = True


                if avg_scores['accuracy'] > highest_accuracy:
                    highest_config = current_config
                    highest_accuracy = avg_scores['accuracy']
                    print("highest_config: ", highest_config)
                    print("highest_accuracy: ", highest_accuracy)

            if best_filter_size:
                break
        if best_num_filter:
            break
            
print("highest_config: ", highest_config)
print("highest_accuracy: ", highest_accuracy)

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.3952 - loss: 1.7335 - val_accuracy: 0.5036 - val_loss: 1.4665
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 7ms/step - accuracy: 0.5656 - loss: 1.2387 - val_accuracy: 0.5377 - val_loss: 1.3129
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 7ms/step - accuracy: 0.6258 - loss: 1.0728 - val_accuracy: 0.5402 - val_loss: 1.3355
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.6732 - loss: 0.9519 - val_accuracy: 0.5773 - val_loss: 1.2310
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.7062 - loss: 0.8485 - val_accuracy: 0.5780 - val_loss: 1.2565
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 943us/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - accuracy: 0.6999 - loss: 0.8698

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 14ms/step - accuracy: 0.4137 - loss: 1.7158 - val_accuracy: 0.5240 - val_loss: 1.4677
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 14ms/step - accuracy: 0.5972 - loss: 1.1388 - val_accuracy: 0.5787 - val_loss: 1.2086
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 14ms/step - accuracy: 0.6656 - loss: 0.9545 - val_accuracy: 0.5938 - val_loss: 1.1747
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 14ms/step - accuracy: 0.7189 - loss: 0.8130 - val_accuracy: 0.5881 - val_loss: 1.2325
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 14ms/step - accuracy: 0.7617 - loss: 0.6919 - val_accuracy: 0.5909 - val_loss: 1.2755
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 14ms/step - accuracy: 0.7519 - loss: 0.

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 33ms/step - accuracy: 0.4216 - loss: 1.7028 - val_accuracy: 0.5110 - val_loss: 1.5184
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 33ms/step - accuracy: 0.6207 - loss: 1.0925 - val_accuracy: 0.5846 - val_loss: 1.1708
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 33ms/step - accuracy: 0.6921 - loss: 0.8930 - val_accuracy: 0.6084 - val_loss: 1.1332
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 33ms/step - accuracy: 0.7438 - loss: 0.7483 - val_accuracy: 0.6153 - val_loss: 1.1556
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 33ms/step - accuracy: 0.7892 - loss: 0.6273 - val_accuracy: 0.6180 - val_loss: 1.1878
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 33ms/step - accuracy: 0.7681 - lo

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 85ms/step - accuracy: 0.4245 - loss: 1.7192 - val_accuracy: 0.5148 - val_loss: 1.6256
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 84ms/step - accuracy: 0.6229 - loss: 1.0893 - val_accuracy: 0.5916 - val_loss: 1.1556
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 84ms/step - accuracy: 0.6929 - loss: 0.8836 - val_accuracy: 0.6133 - val_loss: 1.1531
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 85ms/step - accuracy: 0.7427 - loss: 0.7526 - val_accuracy: 0.6174 - val_loss: 1.1588
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 84ms/step - accuracy: 0.7942 - loss: 0.6063 - val_accuracy: 0.6109 - val_loss: 1.2377
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 84ms/step - accuracy: 0.7771 - lo

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m72s[0m 189ms/step - accuracy: 0.4048 - loss: 1.9629 - val_accuracy: 0.5115 - val_loss: 1.7083
Epoch 2/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 183ms/step - accuracy: 0.6258 - loss: 1.0764 - val_accuracy: 0.5987 - val_loss: 1.1500
Epoch 3/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 185ms/step - accuracy: 0.7031 - loss: 0.8551 - val_accuracy: 0.6123 - val_loss: 1.1566
Epoch 4/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 186ms/step - accuracy: 0.7701 - loss: 0.6818 - val_accuracy: 0.6031 - val_loss: 1.2262
Epoch 5/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 187ms/step - accuracy: 0.8253 - loss: 0.5222 - val_accuracy: 0.6144 - val_loss: 1.2664
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step
Epoch 1/5
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 187ms/step - accuracy: 0.799

In [4]:
X_combined_reshaped = X_combined.reshape(-1, 32, 32, 3)
y_combined_categorical = to_categorical(y_combined, num_classes=10)
X_combined_reshaped = X_combined_reshaped.astype('float32') / 255.0

In [7]:
# GFNet simple
def global_filter_layer(x, filters, kernel_size=3, strides=1):
    # simple Conv2d
    x = layers.Conv2D(filters, kernel_size, strides=strides, padding='same')(x)
    return x

# build GFNet model
def build_gfnet(input_shape, num_classes):
    inputs = layers.Input(shape=input_shape)

    x = global_filter_layer(inputs, filters=64, kernel_size=7, strides=2)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    x = global_filter_layer(x, filters=128, kernel_size=3)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    x = layers.Conv2D(256, (3, 3), padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(num_classes, activation='softmax')(x)

    model = models.Model(inputs, x)
    return model

input_shape = (32, 32, 3)
num_classes = 10

model = build_gfnet(input_shape, num_classes)

model.compile(optimizer='adam', 
              loss='categorical_crossentropy', 
              metrics=['accuracy'])

model.summary()

history = model.fit(X_combined_reshaped, y_combined_categorical, 
                    epochs=10, 
                    batch_size=64, 
                    validation_split=0.1)

print(f"Training finished. Accuracy: {history.history['accuracy'][-1]}")

Epoch 1/10
[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 81ms/step - accuracy: 0.3486 - loss: 1.7906 - val_accuracy: 0.3282 - val_loss: 2.0693
Epoch 2/10
[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 82ms/step - accuracy: 0.4846 - loss: 1.4371 - val_accuracy: 0.2935 - val_loss: 2.5365
Epoch 3/10
[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 82ms/step - accuracy: 0.5326 - loss: 1.3120 - val_accuracy: 0.3577 - val_loss: 1.7651
Epoch 4/10
[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m68s[0m 81ms/step - accuracy: 0.5665 - loss: 1.2188 - val_accuracy: 0.3998 - val_loss: 1.9559
Epoch 5/10
[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 82ms/step - accuracy: 0.5896 - loss: 1.1599 - val_accuracy: 0.4965 - val_loss: 1.4356
Epoch 6/10
[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m68s[0m 81ms/step - accuracy: 0.6119 - loss: 1.0995 - val_accuracy: 0.5488 - val_loss: 1.2741
Epoch 7/10
[1m8

In [10]:
# 定义Swin Transformer模型
class SwinTransformerBlock(layers.Layer):
    def __init__(self, num_heads, window_size, mlp_ratio=4.0):
        super(SwinTransformerBlock, self).__init__()
        self.attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=64)
        self.mlp = models.Sequential([
            layers.Dense(int(64 * mlp_ratio), activation='gelu'),
            layers.Dense(64)
        ])
        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.norm2 = layers.LayerNormalization(epsilon=1e-5)

    def call(self, inputs):
        # Multi-head self-attention
        attn = self.attention(inputs, inputs)
        x = self.norm1(inputs + attn)  # Residual connection

        # MLP
        mlp_output = self.mlp(x)
        return self.norm2(x + mlp_output)  # Residual connection

class SwinTransformer(models.Model):
    def __init__(self, num_classes):
        super(SwinTransformer, self).__init__()
        self.conv = layers.Conv2D(64, kernel_size=(3, 3), padding='same', activation='relu')  # 增加卷积层
        self.block1 = SwinTransformerBlock(num_heads=4, window_size=4)
        self.block2 = SwinTransformerBlock(num_heads=4, window_size=4)
        self.flatten = layers.Flatten()
        self.dense = layers.Dense(num_classes, activation='softmax')

    def call(self, inputs):
        x = self.conv(inputs)  # 通过卷积层进行通道数调整
        x = tf.reshape(x, (-1, x.shape[1], x.shape[2], 64))  # 可能需要根据卷积后输出形状调整
        x = self.block1(x)
        x = self.block2(x)
        x = self.flatten(x)
        return self.dense(x)

model = SwinTransformer(num_classes=10)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# 训练模型
history = model.fit(X_combined_reshaped, y_combined_categorical, batch_size=64, epochs=10, validation_split=0.2)

# 保存模型
model.save('swin_transformer_cifar10.h5')
print(f"Training finished. Accuracy: {history.history['accuracy'][-1]}")

Epoch 1/10
[1m 14/750[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m54:02[0m 4s/step - accuracy: 0.1334 - loss: 29.3850

KeyboardInterrupt: 