In [None]:
! pip install transformers tensorflow 


In [None]:
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import numpy as np
import os

# Directory containing all images, organised as one sub-folder per leaf type.
data_dir = "path_to_data"

# Lists to hold images and labels
images = []
labels = []

# Get all the folder names (leaf types).
# - Only real directories are kept, so stray files next to the class folders
#   (e.g. .DS_Store, a README) do not crash the inner os.listdir call.
# - The names are sorted because os.listdir returns entries in arbitrary order;
#   without sorting, the label indices could differ between runs or machines.
leaf_types = sorted(
    folder_name
    for folder_name in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, folder_name))
)

# Create a mapping of leaf types to integer labels
label_mapping = {leaf_type: idx for idx, leaf_type in enumerate(leaf_types)}

# Load images and labels. File names are sorted as well so that the seeded
# split below always sees the samples in the same order.
for folder_name in leaf_types:
    folder_path = os.path.join(data_dir, folder_name)
    for img_name in sorted(os.listdir(folder_path)):
        img_path = os.path.join(folder_path, img_name)
        if not os.path.isfile(img_path):
            # Nested directories are not images; skip them.
            continue
        img = load_img(img_path, target_size=(224, 224))
        img_array = img_to_array(img)
        images.append(img_array)
        labels.append(label_mapping[folder_name])

if not images:
    raise ValueError(f"No images found under {data_dir!r}; expected one sub-folder per leaf type.")

# Convert to arrays; labels become one-hot rows with one column per leaf type.
images = np.array(images)
labels = to_categorical(np.array(labels), num_classes=len(label_mapping))

# Split into training, validation, and testing sets (70% train, 15% validation, 15% test):
# first hold out 30% of the data, then cut that hold-out set in half.
train_images, holdout_images, train_labels, holdout_labels = train_test_split(
    images, labels, test_size=0.3, random_state=42
)
val_images, test_images, val_labels, test_labels = train_test_split(
    holdout_images, holdout_labels, test_size=0.5, random_state=42
)


In [None]:
# ImageDataGenerator is not imported by any earlier cell, so import it here;
# without this the cell fails with a NameError on a fresh kernel.
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Data augmentation: only the training images are randomly transformed.
# Validation and test images are just rescaled to the [0, 1] range.
train_datagen = ImageDataGenerator(
    rescale=1.0/255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)
val_datagen = ImageDataGenerator(rescale=1.0/255)
test_datagen = ImageDataGenerator(rescale=1.0/255)

# Create data generators. flow() shuffles by default; that is wanted for
# training, but validation/test batches are kept in a fixed order so that
# evaluation and the prediction preview are repeatable.
train_generator = train_datagen.flow(train_images, train_labels, batch_size=32)
val_generator = val_datagen.flow(val_images, val_labels, batch_size=32, shuffle=False)
test_generator = test_datagen.flow(test_images, test_labels, batch_size=32, shuffle=False)


### Model loading

In [None]:
from transformers import DeiTFeatureExtractor, DeiTForImageClassification

# The rest of the notebook drives the model through the Keras API
# (compile / fit / evaluate / predict), so the TensorFlow variant of DeiT is
# needed; the PyTorch DeiTForImageClassification class does not provide
# Keras-style fit/evaluate.
from transformers import TFDeiTForImageClassification

# Load the feature extractor (kept for preprocessing; not used by the cells shown here)
feature_extractor = DeiTFeatureExtractor.from_pretrained("facebook/deit-base-distilled-patch16-224")

# Load the DeiT model with a classification head sized for the leaf classes.
# ignore_mismatched_sizes lets the pretrained backbone load even though the
# checkpoint's head has a different number of outputs.
# NOTE(review): the TF vision models presumably expect channels-first
# `pixel_values`, while the Keras generators yield channels-last batches --
# confirm the expected input layout before training.
model = TFDeiTForImageClassification.from_pretrained(
    "facebook/deit-base-distilled-patch16-224",
    num_labels=len(label_mapping),
    ignore_mismatched_sizes=True,
)


### Model training

In [None]:
from tensorflow.keras.optimizers import Adam

from tensorflow.keras.losses import CategoricalCrossentropy

# Compile the model.
# - `learning_rate` replaces the deprecated `lr` argument, which newer Keras
#   versions reject.
# - Labels are one-hot, so categorical cross-entropy is used; from_logits=True
#   because the classification head is assumed to emit raw logits rather than
#   softmax probabilities (NOTE(review): confirm for the model loaded above).
model.compile(
    optimizer=Adam(learning_rate=0.0001),
    loss=CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# Train the model. The step counts are derived from the generators so that one
# epoch is exactly one pass over the data, whatever the dataset size; fixed
# step counts can over- or under-run the available batches.
history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=10,
    steps_per_epoch=len(train_generator),
    validation_steps=len(val_generator),
)


### Evaluation

In [None]:
# Score the trained model on the held-out test batches and report the accuracy.
test_metrics = model.evaluate(test_generator)
test_loss, test_accuracy = test_metrics
print("Test accuracy:", test_accuracy)


### Interpretation (training curves)

In [None]:
import matplotlib.pyplot as plt

# One figure per tracked quantity: the training curve first, then the
# validation curve, each read from the history recorded by model.fit.
curve_specs = (
    ('accuracy', 'val_accuracy', 'Model accuracy', 'Accuracy'),
    ('loss', 'val_loss', 'Model loss', 'Loss'),
)

for train_key, val_key, heading, y_label in curve_specs:
    plt.plot(history.history[train_key])
    plt.plot(history.history[val_key])
    plt.title(heading)
    plt.ylabel(y_label)
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')
    plt.show()


### Testing

In [None]:
# Run the test-set evaluation again; the result unpacks into the loss
# followed by the accuracy metric, and only the accuracy is printed.
test_loss, test_accuracy = model.evaluate(x=test_generator)
print("Test accuracy:", test_accuracy)

In [None]:
import numpy as np

# Get a batch of test images and labels
test_images_batch, test_labels_batch = next(test_generator)

# Predict the labels.
# NOTE(review): Hugging Face models presumably return an output object whose
# scores live in `.logits`, whereas plain Keras models return the score array
# itself -- getattr handles both cases.
raw_predictions = model.predict(test_images_batch)
predicted_labels_batch = getattr(raw_predictions, "logits", raw_predictions)
predicted_labels = np.argmax(predicted_labels_batch, axis=1)
true_labels = np.argmax(test_labels_batch, axis=1)  # labels are one-hot rows

# Reverse the label mapping to get the leaf type names
reverse_label_mapping = {value: key for key, value in label_mapping.items()}

# Plot some images with true and predicted labels. The count is capped at the
# batch size so a short batch (small test set) does not raise an IndexError.
num_to_show = min(5, len(test_images_batch))  # Change 5 to the number of images you want to visualize
for i in range(num_to_show):
    plt.imshow(test_images_batch[i])
    plt.title(f"True: {reverse_label_mapping[true_labels[i]]}, Predicted: {reverse_label_mapping[predicted_labels[i]]}")
    plt.show()
