# What is this?

The purpose of this notebook is to test loading and running the trained ViT model. The model was trained by Mitch for 10 epochs. Top-1 accuracy turns out to be 88%. 

In [1]:
import json
import os

import cv2
import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import load_model
import tensorflow_addons as tfa
import tensorflow_hub as hub

from keras import layers, models, optimizers, regularizers
from keras.applications import EfficientNetB0
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

# import matplotlib.pyplot as plt

In [2]:
# Load the EuroSAT split manifests (one CSV per split).
dir_train = pd.read_csv('dataset/EuroSAT/train.csv')
dir_valid = pd.read_csv('dataset/EuroSAT/validation.csv')
dir_test = pd.read_csv('dataset/EuroSAT/test.csv')  # not referenced by the visible cells below

# Merge train + validation into a single manifest with a fresh 0..n-1 index.
# Columns are selected by name rather than by position so the frame always
# carries the 'Filename' and 'ClassName' fields that the image-loading cell
# reads, regardless of any extra index columns stored in the CSVs.
# NOTE(review): the evaluation further down uses a random split of this merged
# frame rather than dir_test -- confirm none of it was seen during training.
img_dir = pd.concat([dir_train, dir_valid], ignore_index=True)
img_dir = img_dir[['Filename', 'Label', 'ClassName']]
print(img_dir.shape)
# img_dir = img_dir.iloc[:100, :] # limit sample size when testing
img_dir

(24300, 3)


Unnamed: 0,Filename,Label,ClassName
0,AnnualCrop/AnnualCrop_142.jpg,0,AnnualCrop
1,HerbaceousVegetation/HerbaceousVegetation_2835...,2,HerbaceousVegetation
2,PermanentCrop/PermanentCrop_1073.jpg,6,PermanentCrop
3,Industrial/Industrial_453.jpg,4,Industrial
4,HerbaceousVegetation/HerbaceousVegetation_1810...,2,HerbaceousVegetation
...,...,...,...
24295,SeaLake/SeaLake_1943.jpg,9,SeaLake
24296,AnnualCrop/AnnualCrop_211.jpg,0,AnnualCrop
24297,Industrial/Industrial_1428.jpg,4,Industrial
24298,AnnualCrop/AnnualCrop_2571.jpg,0,AnnualCrop


In [4]:
# Load every image listed in img_dir together with its class name.
base_path = 'dataset/EuroSAT/'
images = []
classes = []

# Iterate the two needed columns directly instead of DataFrame.iterrows(),
# which builds a full Series object for every row.
for filename, class_name in zip(img_dir['Filename'], img_dir['ClassName']):
    img_path = os.path.join(base_path, filename)
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None rather
        # than raising; fail here so the error names the offending path.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img = cv2.resize(img, (64, 64))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # convert OpenCV's BGR order to RGB
    images.append(img)
    classes.append(class_name)

# Normalize images to [0, 1]. float32 is used instead of the float64 that a
# plain `/ 255.0` on the integer array would produce, halving memory use.
images = np.asarray(images, dtype=np.float32) / 255.0
# One-hot encode the class names (one column per distinct class name).
labels = pd.get_dummies(classes).values

In [5]:
# Augmentation settings for the training generator: rotation, width/height
# shifts, and flips along both axes.
augmentation_options = dict(
    rotation_range=180,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
)
train_data_generator = ImageDataGenerator(**augmentation_options)
# The validation generator is created without any augmentation options.
valid_data_generator = ImageDataGenerator()

# Hold out 20% of the samples for evaluation; the fixed random_state makes the
# split repeatable across runs.
X_train, X_test, y_train, y_test = train_test_split(
    images,
    labels,
    test_size=0.2,
    random_state=42,
)

In [6]:
# Load the saved ViT classifier from the "model_vit_classifier" path.
# `tfa` and `load_model` are imported in the notebook's import cell so that all
# dependencies are declared in one place.
# AdamW from TensorFlow Addons is registered as a custom object so Keras can
# resolve that name while deserializing the model (presumably the optimizer the
# model was compiled with -- confirm with the training notebook).
custom_objects = {
    "AdamW": tfa.optimizers.AdamW
}

model_1 = load_model("model_vit_classifier", custom_objects=custom_objects)


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 





In [7]:
# Evaluate the model: run inference on the held-out split.
batch_size = 16  # Choose a smaller batch size according to your GPU memory capacity

# Model.predict already walks the input in chunks of `batch_size`, so a single
# call replaces a manual Python loop over slices. Calling predict() once per
# slice re-enters the prediction machinery (and prints a progress bar) for
# every 16 samples.
y_pred = model_1.predict(X_test, batch_size=batch_size)

# Convert class scores / one-hot rows to integer class indices.
y_pred_classes = np.argmax(y_pred, axis=1)
y_true = np.argmax(y_test, axis=1)



In [8]:
# Per-class precision/recall/F1 summary, followed by the raw confusion matrix
# for true vs. predicted class indices.
report = classification_report(y_true, y_pred_classes)
matrix = confusion_matrix(y_true, y_pred_classes)
print(report)
print(matrix)

              precision    recall  f1-score   support

           0       0.91      0.89      0.90       551
           1       0.94      0.94      0.94       547
           2       0.85      0.88      0.86       535
           3       0.80      0.77      0.79       459
           4       0.90      0.95      0.92       453
           5       0.83      0.83      0.83       371
           6       0.80      0.84      0.82       456
           7       0.89      0.97      0.92       502
           8       0.91      0.75      0.83       469
           9       0.97      0.95      0.96       517

    accuracy                           0.88      4860
   macro avg       0.88      0.88      0.88      4860
weighted avg       0.88      0.88      0.88      4860

[[491   2   3   5   1  11  30   1   6   1]
 [  0 515   8   0   0  18   0   0   0   6]
 [  7   4 469   4   5   5  23  10   2   6]
 [  2   0   4 355  30   8  11  35  14   0]
 [  0   0   2   4 432   0   0  15   0   0]
 [ 11   4   9   8   0 308 