# In [None]:
import numpy as np
from sklearn.model_selection import KFold
import tensorflow as tf
import tensorflow_addons as tfa
from vit_keras import vit

# ---------------------------------------------------------------------------
# Repeated training / evaluation of a ViT-B/16 image classifier.
#
# For each of `num_folds` rounds this script builds a fresh pretrained
# ViT-B/16 backbone with a small dense head, trains it with checkpointing and
# early stopping, reloads the best checkpoint of that round and evaluates it
# on `test_generator`.  The mean test accuracy over all rounds is printed at
# the end.
#
# Relies on `train_generator`, `validation_generator` and `test_generator`
# being defined before this point (presumably Keras directory/dataframe
# iterators, since `.filenames` is read from the training one -- confirm
# against the code that creates them).
# ---------------------------------------------------------------------------

# Side length (in pixels) of the square RGB input images
image_size = 224

# Number of output units of the softmax classification head
num_classes = 7

# Number of folds for cross-validation
num_folds = 5

# Training hyper-parameters; identical for every fold, so defined once here
learning_rate = 1e-4
epochs = 30

# Create K-Fold cross-validator (fixed seed so the splits are reproducible)
kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)

# Test accuracy obtained in each fold, in fold order
test_accuracies = []

# Perform cross-validation
# NOTE(review): `train_index` / `val_index` are produced by the splitter but
# are never applied to the data -- every fold trains on the full
# `train_generator` and validates on the full `validation_generator`.  As
# written this is `num_folds` repeated runs on the same split rather than true
# K-fold cross-validation.  Building per-fold generators from these indices
# requires the generator construction code, which is not visible here.
for fold, (train_index, val_index) in enumerate(
    kf.split(train_generator.filenames), start=1
):
    print(f"Fold: {fold}")

    # Release graph/layer state left over from the previous fold's model so
    # memory does not grow with every fold.
    tf.keras.backend.clear_session()

    # Create the ViTb16 backbone for this fold (pretrained weights, no
    # classification top -- the head is added below).
    vit_model = vit.vit_b16(
        image_size=image_size,
        activation='sigmoid',
        pretrained=True,
        pretrained_top=False,
        include_top=False,
        classes=num_classes
    )

    # Backbone followed by a small normalised dense head ending in a softmax
    # over `num_classes`.
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(image_size, image_size, 3)),
        vit_model,
        tf.keras.layers.Flatten(),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dense(11, activation=tfa.activations.gelu),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dense(num_classes, activation='softmax')
    ],
        name='vision_transformer'
    )

    # Compile the model; a new optimizer per fold so no optimizer state is
    # carried over between folds.
    optimizer = tfa.optimizers.RectifiedAdam(learning_rate=learning_rate)

    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.2),
        metrics=['accuracy']
    )

    # Print the summary of the model for this fold
    model.summary()

    # Save the best model (highest validation accuracy) seen during this fold
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=f"best_model_vitb16_fold{fold}.h5",
        monitor="val_accuracy",
        verbose=1,
        save_best_only=True,
        mode="max"
    )

    # Stop once validation accuracy has not improved by at least `min_delta`
    # for `patience` consecutive epochs, and roll back to the best weights.
    earlystopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_accuracy',
        min_delta=1e-4,
        patience=5,
        mode='max',
        restore_best_weights=True,
        verbose=1
    )

    # Train the model for this fold in ONE `fit` call.
    # Calling `fit(epochs=1)` repeatedly re-initialises EarlyStopping's
    # patience counter and best value at the start of every call, so early
    # stopping could never trigger.  With a single multi-epoch call the
    # callback state persists across epochs.  Keras iterators are restarted
    # by `fit` between epochs, so no manual `reset()` is needed
    # (assumes the generators are Keras Sequence-style iterators -- confirm).
    model.fit(
        train_generator,
        epochs=epochs,
        validation_data=validation_generator,
        callbacks=[checkpoint, earlystopping]
    )

    # Load the best saved model weights for this fold
    model.load_weights(f"best_model_vitb16_fold{fold}.h5")

    # Evaluate the model on the test set for this fold
    test_loss, test_accuracy = model.evaluate(test_generator)
    print(f"Test Accuracy (Fold {fold}):", test_accuracy)

    test_accuracies.append(test_accuracy)

# Calculate and print the average accuracy over all folds
average_accuracy = np.mean(test_accuracies)
print("Average Test Accuracy:", average_accuracy)
