In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import shutil

In [None]:
!pip install tensorflow-addons
!pip install vit_keras

In [None]:
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import load_img, save_img
from tensorflow.keras.models import load_model
from keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint, TensorBoard
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from vit_keras import vit

# Obtain Image

In [None]:
# Images are stored as NumPy .npy files: declare both locations, then load them.
train_image_npy_path = './train_image.npy'  # training images
test_image_npy_path = './test_image.npy'  # test images

train_image, test_image = (
    np.load(npy_path)
    for npy_path in (train_image_npy_path, test_image_npy_path)
)

# Obtain Label

Three-Category Dataset

In [None]:
# Three-Category Dataset labels (.npy). These are bound to train_label / test_label,
# the same names the Nine-Class cell assigns, so only the last-run cell takes effect.
train_label_npy_path = './train_label_3.npy'  # training labels, three categories
test_label_npy_path = './test_label_3.npy'  # test labels, three categories

train_label, test_label = (
    np.load(npy_path)
    for npy_path in (train_label_npy_path, test_label_npy_path)
)

Nine-Class Dataset

In [None]:
# Nine-Class Dataset labels (.npy). Running this cell rebinds train_label / test_label,
# replacing whatever the Three-Category cell loaded.
train_label_npy_path = './train_label_9.npy'  # training labels, nine classes
test_label_npy_path = './test_label_9.npy'  # test labels, nine classes

train_label, test_label = (
    np.load(npy_path)
    for npy_path in (train_label_npy_path, test_label_npy_path)
)

# Model

In [None]:
# img_size: side length used for the backbones' input_shape (img_size, img_size, 3).
# classes: value passed as num_classes when the backbones are built.
img_size, classes = 128, 9

### Model Construction

In [None]:
!pip install -U kecam

SwinT

In [None]:
from keras_cv_attention_models import swin_transformer_v2

# Weight file handed to the backbone constructor via `pretrained`.
pretrained = os.path.expanduser('./swin_transformer_v2_base_window16_256_imagenet.h5') #swin_transformer_v2 Model Path

# Backbone built for img_size x img_size RGB input with a `classes`-way output.
base_model = swin_transformer_v2.SwinTransformerV2Base_window16(
    input_shape=(img_size, img_size, 3),
    pretrained=pretrained,
    num_classes=classes,
)

# The backbone is the only layer; the Sequential wrapper just gives it a name.
Model = tf.keras.Sequential(layers=[base_model], name='swim_transformer')

Model.summary()

DaViT

In [None]:
from keras_cv_attention_models import davit

# Weight file handed to the backbone constructor via `pretrained`.
pretrained = os.path.expanduser('./davit_b_imagenet.h5') #DaViT Model Path

# Backbone built for img_size x img_size RGB input with a `classes`-way output.
# NOTE: this rebinds base_model / Model, replacing the ones built in the SwinT cell.
base_model = davit.DaViT_B(
    input_shape=(img_size, img_size, 3),
    pretrained=pretrained,
    num_classes=classes,
)

# The backbone is the only layer; the Sequential wrapper just gives it a name.
Model = tf.keras.Sequential(layers=[base_model], name='DaViT')

Model.summary()

Train the output layers

In [None]:
# Make the whole backbone trainable first, then freeze every layer below the cut-off.
base_model.trainable = True

# Print the backbone depth so the cut-off index can be checked against it
print("Number of layers in the base model: ", len(base_model.layers))

# Index of the first layer that stays trainable
# NOTE(review): 777 presumably targets the last few layers of the chosen backbone - confirm against the count printed above.
fine_tune_at = 777

# Layers at positions [0, fine_tune_at) are frozen
for position, layer in enumerate(base_model.layers):
    if position < fine_tune_at:
        layer.trainable = False

Model.summary()

Compile Model

In [None]:
# Adam at 1e-3 with categorical cross-entropy; categorical accuracy is the tracked
# metric (the callbacks monitor its validation counterpart).
initial_optimizer = keras.optimizers.Adam(1e-3)
initial_loss = keras.losses.CategoricalCrossentropy()
initial_metrics = [tf.keras.metrics.CategoricalAccuracy()]

Model.compile(optimizer=initial_optimizer, loss=initial_loss, metrics=initial_metrics)

Callbacks

In [None]:
# Where ModelCheckpoint writes the best model; fill in before training.
checkpoint_filepath = '' #checkpoint_filepath

# Multiply the learning rate by 0.5 after 2 epochs without improvement, never below 1e-6.
lr = ReduceLROnPlateau(
    monitor='val_categorical_accuracy',
    min_delta=0,
    factor=0.5,
    patience=2,
    min_lr=1e-6,
)
# NOTE(review): `stop` is created but not included in model_callbacks below - confirm
# whether early stopping was meant to be active.
stop = EarlyStopping(
    monitor="val_categorical_accuracy",
    min_delta=0,
    patience=5,
    restore_best_weights=True,
)
# Save the full model (not just weights) whenever validation accuracy reaches a new maximum.
ck = ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=False,
    monitor='val_categorical_accuracy',
    mode='max',
    save_best_only=True,
)

model_callbacks = [lr, ck]

Train Model

In [None]:
# Training length and mini-batch size for the first stage
EPOCHS = 20
BATCH_SIZE = 32

# validation_split=0.1 holds out 10% of the training arrays for validation.
Model_history = Model.fit(
    train_image,
    train_label,
    validation_split=0.1,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    shuffle=True,
    verbose=1,
    callbacks=model_callbacks,
)

Plot Accuracy & Loss

In [None]:
# Set the seaborn theme BEFORE creating the figure: styles only affect figures
# created afterwards, so setting it after plotting left these curves unstyled.
sns.set(style='darkgrid')

# Per-epoch curves recorded by Model.fit
acc = Model_history.history['categorical_accuracy']
val_acc = Model_history.history['val_categorical_accuracy']
loss = Model_history.history['loss']
val_loss = Model_history.history['val_loss']

# Two side-by-side panels: accuracy on the left, loss on the right
fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))

ax_acc.plot(acc, label='Train Accuracy')
ax_acc.plot(val_acc, label='Validation Accuracy', linestyle='--')
ax_acc.set_title('Accuracy', fontsize=16)
ax_acc.set_xlabel('Epoch', fontsize=14)
ax_acc.set_ylabel('Acc', fontsize=14)
ax_acc.legend()

ax_loss.plot(loss, label='Train loss')
ax_loss.plot(val_loss, label='Validation loss', linestyle='--')
ax_loss.set_title('Loss', fontsize=16)
ax_loss.set_xlabel('Epoch', fontsize=14)
ax_loss.set_ylabel('Loss', fontsize=14)
ax_loss.legend()

plt.show()

Model Evaluation

In [None]:
# evaluate() returns the loss followed by the compiled metric (categorical accuracy).
score = Model.evaluate(test_image, test_label, verbose=1)
for caption, value in zip(('Test loss:', 'Test accuracy:'), score):
    print(caption, value)

### Fine-tune Model

In [None]:
# Reset the backbone to fully trainable, then freeze every layer below the cut-off.
base_model.trainable = True

# Print the backbone depth so the cut-off index can be checked against it
print("Number of layers in the base model: ", len(base_model.layers))

# Index of the first layer that stays trainable in the fine-tuning stage
fine_tune_at = 560

# Layers at positions [0, fine_tune_at) are frozen
for position, layer in enumerate(base_model.layers):
    if position < fine_tune_at:
        layer.trainable = False

Model.summary()

Compile Model

In [None]:
# Recompile so the changed `trainable` flags take effect for the next fit() call.
# NOTE(review): the rate is still 1e-3; fine-tuning usually uses a smaller one - confirm.
fine_tune_optimizer = keras.optimizers.Adam(1e-3)
fine_tune_loss = keras.losses.CategoricalCrossentropy()
fine_tune_metrics = [tf.keras.metrics.CategoricalAccuracy()]

Model.compile(optimizer=fine_tune_optimizer, loss=fine_tune_loss, metrics=fine_tune_metrics)

Callbacks

In [None]:
# Where ModelCheckpoint writes the best fine-tuned model; fill in before training.
checkpoint_filepath = '' #checkpoint_filepath

# Multiply the learning rate by 0.5 after 2 epochs without improvement, never below 1e-6.
lr = ReduceLROnPlateau(
    monitor='val_categorical_accuracy',
    min_delta=0,
    factor=0.5,
    patience=2,
    min_lr=1e-6,
)
# NOTE(review): `stop` is created but not included in model_callbacks below - confirm
# whether early stopping was meant to be active.
stop = EarlyStopping(
    monitor="val_categorical_accuracy",
    min_delta=0,
    patience=5,
    restore_best_weights=True,
)
# Save the full model (not just weights) whenever validation accuracy reaches a new maximum.
ck = ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=False,
    monitor='val_categorical_accuracy',
    mode='max',
    save_best_only=True,
)

model_callbacks = [lr, ck]

Train Model

In [None]:
# Continue the epoch counter where the first training stage stopped.
FINETUNE_EPOCHS = 10
INITIAL_EPOCHS = Model_history.epoch[-1]+1
TOTAL_EPOCHS =  INITIAL_EPOCHS + FINETUNE_EPOCHS

# validation_split must match the first stage (0.1). Keras takes the validation
# samples from the END of the arrays before shuffling, so a larger fraction here
# (previously 0.25) would validate on samples the model was already trained on in
# the first stage, inflating the monitored val_categorical_accuracy.
Model_history_fine = Model.fit(
    train_image,
    train_label,
    validation_split=0.1,
    epochs=TOTAL_EPOCHS,
    initial_epoch=INITIAL_EPOCHS,
    batch_size=BATCH_SIZE,
    shuffle=True,
    verbose=1,
    callbacks=model_callbacks,
)

In [None]:
# Re-evaluate on the test set after fine-tuning: loss first, then categorical accuracy.
score = Model.evaluate(test_image, test_label, verbose=1)
test_loss, test_accuracy = score[0], score[1]
print('Test loss:', test_loss)
print('Test accuracy:', test_accuracy)

Save Model

In [None]:
# Destination for the saved model; fill in before running.
save_checkpoint_filepath = '' #save_checkpoint_filepath
Model.save(filepath=save_checkpoint_filepath)

Load Model

In [None]:
# Path of the saved model to restore; fill in before running.
load_checkpoint_filepath = '' #load_checkpoint_filepath
# Load from the path defined above (previously this read `checkpoint_filepath`,
# the training-checkpoint variable, and ignored load_checkpoint_filepath).
Model = load_model(load_checkpoint_filepath)

Model Prediction

In [None]:
# Run the model on the test images first; the original cell used `test_pred`
# before it was ever assigned, which fails on a fresh kernel.
test_prob = Model.predict(test_image, verbose=1)

# Collapse per-class scores / label vectors to class indices (argmax over axis 1).
test_pred = np.argmax(test_prob, axis=1)
test_label_true = np.argmax(test_label, axis=1)

Classification Report

In [None]:
from sklearn.metrics import classification_report

# Display names for each label set, ordered by class index.
target_names_three = ['0(text)', '1(target map)', '2(non-target graphic)']
target_names_nine = ['0(text)', '1(scenic map)', '2(city map)', '3(administrative map)', '4(star map)', '5(photograph)', '6(human figure)', '7(building)', '8(object)']

# Follow the configured number of classes instead of hard-coding the nine-class names.
target_names = target_names_three if classes == 3 else target_names_nine

print('\nClassification Report\n')
# Passing `labels` explicitly keeps the name/index mapping valid even when some
# class never occurs in the test labels or predictions.
print(classification_report(
    test_label_true,
    test_pred,
    labels=list(range(len(target_names))),
    target_names=target_names,
    digits=4,
))

Confusion Matrix

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Set the seaborn theme BEFORE creating the figure; applying it after plotting
# left the matrix drawn with whatever theme was active earlier (darkgrid).
sns.set(style='white')

# Tick labels for each label set, ordered by class index.
display_labels_three=['0(text)', '1(target map)', '2(non-target graphic)']
display_labels_nine=['0(text)', '1(scenic)', '2(city)', '3(adm)', '4(star)', '5(photo)', '6(human)', '7(building)', '8(object)']

# Follow the configured number of classes instead of hard-coding the nine-class labels.
display_labels = display_labels_three if classes == 3 else display_labels_nine

# Fix the label set so the matrix always has one row/column per display label,
# even if a class is absent from both the true labels and the predictions.
cm = confusion_matrix(test_label_true, test_pred, labels=np.arange(len(display_labels)))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_labels)

fig, ax = plt.subplots(figsize=(15,15))
ax.set_title('ConfusionMatrix', fontsize=20)
disp.plot(cmap=plt.cm.Blues, values_format='g', ax=ax)
plt.show()

Error Example

In [None]:
def saveImage(a,b,i):
    """Copy one misclassified test image into a folder named ``<a>_<b>``.

    Args:
        a: True class index of the sample (used as the first half of the folder name).
        b: Predicted class index of the sample (second half of the folder name).
        i: Zero-based position of the sample in the test set; the image file
            is expected to be named ``str(10001 + i) + ".jpg"``.

    Raises:
        OSError: Propagated from ``shutil.copyfile`` / ``os.makedirs``, e.g.
            when the source image does not exist.
    """
    error_example_dir = '' #error_example_dir
    # NOTE(review): presumably the folder holding the numbered test JPEGs - confirm.
    test_label_dir = ''

    # os.path.join keeps the paths valid whether or not the configured
    # directories end with a separator (plain concatenation did not).
    target_dir = os.path.join(error_example_dir, str(a) + "_" + str(b))
    # exist_ok avoids the check-then-create race and the shadowed builtin `dir`.
    os.makedirs(target_dir, exist_ok=True)

    file_name = str(10001+i) + ".jpg"
    source = os.path.join(test_label_dir, file_name)
    destination = os.path.join(target_dir, file_name)
    shutil.copyfile(source,destination)

In [None]:
# Walk true/predicted class indices in parallel and export every mismatch.
for i, (a, b) in enumerate(zip(test_label_true, test_pred)):
    if a == b:
        continue
    saveImage(a,b,i)