In [None]:
pip install patchify

Note: you may need to restart the kernel to use updated packages.


In [None]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model

class ClassToken(Layer):
    """Learnable [CLS] token tiled across the batch.

    Holds one trainable vector of length ``input_shape[-1]``. When called it
    returns that vector broadcast to ``(batch_size, 1, hidden_dim)`` (cast to
    the input dtype) so it can be concatenated in front of the patch tokens.
    """

    def __init__(self, **kwargs):
        # Forward **kwargs (name, dtype, trainable, ...) so Keras can rebuild
        # this layer from its config, e.g. when a saved model is loaded.
        super().__init__(**kwargs)

    def build(self, input_shape):
        # add_weight registers the variable with Keras weight tracking and
        # gives it a readable name in summaries and checkpoints.
        self.w = self.add_weight(
            name="cls",
            shape=(1, 1, input_shape[-1]),
            initializer=tf.random_normal_initializer(),
            dtype=tf.float32,
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        hidden_dim = self.w.shape[-1]

        # One copy of the token per sample, in the same dtype as the inputs.
        cls = tf.broadcast_to(self.w, [batch_size, 1, hidden_dim])
        cls = tf.cast(cls, dtype=inputs.dtype)
        return cls

def mlp(x, cf):
    """Feed-forward sub-block of a transformer encoder layer.

    Projects to ``cf["mlp_dim"]`` units with GELU, then back to
    ``cf["hidden_dim"]`` units (linear). Each projection is followed by
    dropout at ``cf["dropout_rate"]``.
    """
    rate = cf["dropout_rate"]
    for units, activation in ((cf["mlp_dim"], "gelu"), (cf["hidden_dim"], None)):
        x = Dense(units, activation=activation)(x)
        x = Dropout(rate)(x)
    return x

def transformer_encoder(x, cf):
    """One pre-norm transformer encoder layer.

    LayerNorm -> multi-head self-attention -> residual add, then
    LayerNorm -> ``mlp`` -> residual add.

    Required keys in ``cf``: ``num_heads``, ``hidden_dim``, plus those read by ``mlp``.
    Optional keys (defaults reproduce the previously built architecture):
      - ``key_dim``: per-head query/key size; defaults to ``cf["hidden_dim"]``.
        NOTE(review): the reference ViT uses ``hidden_dim // num_heads`` here.
        The current default makes every head full width, so the Q/K/V
        projections are ``num_heads`` times larger. Changing it alters the
        parameter count, so existing checkpoints will no longer load.
      - ``attention_dropout_rate``: dropout on the attention weights
        (default 0.0, the ``MultiHeadAttention`` default).
    """
    skip_1 = x
    x = LayerNormalization()(x)
    x = MultiHeadAttention(
        num_heads=cf["num_heads"],
        key_dim=cf.get("key_dim", cf["hidden_dim"]),
        dropout=cf.get("attention_dropout_rate", 0.0),
    )(x, x)  # self-attention: query and value are the same sequence
    x = Add()([x, skip_1])

    skip_2 = x
    x = LayerNormalization()(x)
    x = mlp(x, cf)
    x = Add()([x, skip_2])

    return x

class PositionEmbedding(Layer):
    """Adds a learnable per-position embedding to a sequence of patch embeddings.

    Wrapping the ``Embedding`` lookup in a layer keeps its weights inside the
    model graph. Calling ``Embedding`` directly on a constant ``tf.range``
    while building a functional model runs eagerly, so its output is baked
    in as a fixed random constant that is never trained or saved.
    """

    def __init__(self, num_patches, hidden_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.hidden_dim = hidden_dim
        self.pos_embedding = Embedding(input_dim=num_patches, output_dim=hidden_dim)

    def call(self, inputs):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        pos_embed = self.pos_embedding(positions)            # (num_patches, hidden_dim)
        return inputs + tf.cast(pos_embed, inputs.dtype)     # broadcast over the batch axis

    def get_config(self):
        config = super().get_config()
        config.update({"num_patches": self.num_patches, "hidden_dim": self.hidden_dim})
        return config


def ViT(cf):
    """Build a Vision Transformer classifier as a Keras functional ``Model``.

    Input: flattened patches of shape
    ``(num_patches, patch_size * patch_size * num_channels)`` per sample.
    Output: softmax probabilities over ``cf["num_classes"]``, taken from the
    class-token position after ``cf["num_layers"]`` encoder layers.

    NOTE: position embeddings are now trainable weights (``PositionEmbedding``).
    Checkpoints saved by the previous version have a different weight list and
    will not load into this model. Loading a saved model also requires
    ``custom_objects={"ClassToken": ClassToken, "PositionEmbedding": PositionEmbedding}``.
    """
    # Inputs
    input_shape = (cf["num_patches"], cf["patch_size"] * cf["patch_size"] * cf["num_channels"])
    inputs = Input(input_shape)                              ## (None, num_patches, patch_size*patch_size*num_channels)

    # Patch + position embeddings
    patch_embed = Dense(cf["hidden_dim"])(inputs)            ## (None, num_patches, hidden_dim)
    embed = PositionEmbedding(cf["num_patches"], cf["hidden_dim"])(patch_embed)

    # Prepend the class token
    token = ClassToken()(embed)                              ## (None, 1, hidden_dim)
    x = Concatenate(axis=1)([token, embed])                  ## (None, num_patches + 1, hidden_dim)

    for _ in range(cf["num_layers"]):
        x = transformer_encoder(x, cf)

    # Classification head on the class-token position
    x = LayerNormalization()(x)
    x = x[:, 0, :]                                           ## (None, hidden_dim)
    x = Dense(cf["num_classes"], activation="softmax")(x)

    model = Model(inputs, x)
    return model


if __name__ == "__main__":
    # Smoke test: build the model with the training configuration and print its summary.
    config = {
        "num_layers": 3,
        "hidden_dim": 768,
        "mlp_dim": 3072,
        "num_heads": 12,
        "dropout_rate": 0.5,   # raised from 0.1
        "num_patches": 256,    # raised from 64
        "patch_size": 16,      # reduced from 32
        "num_channels": 3,
        "num_classes": 88,
    }
    model = ViT(config)
    model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 256, 768)]   0           []                               
                                                                                                  
 dense_8 (Dense)                (None, 256, 768)     590592      ['input_2[0][0]']                
                                                                                                  
 tf.__operators__.add_1 (TFOpLa  (None, 256, 768)    0           ['dense_8[0][0]']                
 mbda)                                                                                            
                                                                                                  
 class_token_1 (ClassToken)     (None, 1, 768)       768         ['tf.__operators__.add_1[0]

In [None]:

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import numpy as np
import cv2
from glob import glob
from sklearn.utils import shuffle   #Need to disable real time scanning of McAfee software for this to run
# import random
from sklearn.model_selection import train_test_split
from patchify import patchify
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, ReduceLROnPlateau, EarlyStopping
#from vit import ViT

""" Hyperparameters """
""" Hyperparameters """
hp = {}
hp["image_size"] = 256
hp["num_channels"] = 3
hp["patch_size"] = 16 #from 32 to 16
# Patches are cut without overlap (step == patch_size), so the image side
# must be an exact multiple of the patch side.
assert hp["image_size"] % hp["patch_size"] == 0, "image_size must be divisible by patch_size"
hp["num_patches"] = (hp["image_size"]**2) // (hp["patch_size"]**2)
hp["flat_patches_shape"] = (hp["num_patches"], hp["patch_size"]*hp["patch_size"]*hp["num_channels"])

hp["batch_size"] = 16 #from 64 to 16
hp["lr"] = 2e-5 #from 1e-4
hp["num_epochs"] = 45
# Order matters: a class's position in this list is its integer label, and
# each name must match a sub-directory name of the dataset root exactly.
hp["class_names"] = ['Apple__black_rot',
                     'Apple__healthy',
                     'Apple__rust',
                     'Apple__scab',
                     'Cassava__bacterial_blight',
                     'Cassava__brown_streak_disease',
                     'Cassava__green_mottle',
                     'Cassava__healthy',
                     'Cassava__mosaic_disease',
                     'Cherry__healthy',
                     'Cherry__powdery_mildew',
                     'Chili__healthy',
                     'Chili__leaf curl',
                     'Chili__leaf spot',
                     'Chili__whitefly',
                     'Chili__yellowish',
                     'Coffee__cercospora_leaf_spot',
                     'Coffee__healthy',
                     'Coffee__red_spider_mite',
                     'Coffee__rust',
                     'Corn__common_rust',
                     'Corn__gray_leaf_spot',
                     'Corn__healthy',
                     'Corn__northern_leaf_blight',
                     'Cucumber__diseased',
                     'Cucumber__healthy',
                     'Gauva__diseased',
                     'Gauva__healthy',
                     'Grape__black_measles',
                     'Grape__black_rot',
                     'Grape__healthy',
                     'Grape__leaf_blight_(isariopsis_leaf_spot)',
                     'Jamun__diseased',
                     'Jamun__healthy',
                     'Lemon__diseased',
                     'Lemon__healthy',
                     'Mango__diseased',
                     'Mango__healthy',
                     'Peach__bacterial_spot',
                     'Peach__healthy',
                     'Pepper_bell__bacterial_spot',
                     'Pepper_bell__healthy',
                     'Pomegranate__diseased',
                     'Pomegranate__healthy',
                     'Potato__early_blight',
                     'Potato__healthy',
                     'Potato__late_blight',
                     'Rice__brown_spot',
                     'Rice__healthy',
                     'Rice__hispa',
                     'Rice__leaf_blast',
                     'Rice__neck_blast',
                     'Soybean__bacterial_blight',
                     'Soybean__caterpillar',
                     'Soybean__diabrotica_speciosa',
                     'Soybean__downy_mildew',
                     'Soybean__healthy',
                     'Soybean__mosaic_virus',
                     'Soybean__powdery_mildew',
                     'Soybean__rust',
                     'Soybean__southern_blight',
                     'Strawberry___leaf_scorch',
                     'Strawberry__healthy',
                     'Sugarcane__bacterial_blight',
                     'Sugarcane__healthy',
                     'Sugarcane__red_rot',
                     'Sugarcane__red_stripe',
                     'Sugarcane__rust',
                     'Tea__algal_leaf',
                     'Tea__anthracnose',
                     'Tea__bird_eye_spot',
                     'Tea__brown_blight',
                     'Tea__healthy',
                     'Tea__red_leaf_spot',
                     'Tomato__bacterial_spot',
                     'Tomato__early_blight',
                     'Tomato__healthy',
                     'Tomato__late_blight',
                     'Tomato__leaf_mold',
                     'Tomato__mosaic_virus',
                     'Tomato__septoria_leaf_spot',
                     'Tomato__spider_mites_(two_spotted_spider_mite)',
                     'Tomato__target_spot',
                     'Tomato__yellow_leaf_curl_virus',
                     'Wheat__brown_rust',
                     'Wheat__healthy',
                     'Wheat__septoria',
                     'Wheat__yellow_rust']
# list.index() returns the first match, so duplicates would silently merge labels.
assert len(set(hp["class_names"])) == len(hp["class_names"]), "duplicate class names"
# Derived, so the one-hot width can never drift out of sync with the list above.
hp["num_classes"] = len(hp["class_names"])
hp["num_layers"] = 3
hp["hidden_dim"] = 768
hp["mlp_dim"] = 3072
hp["num_heads"] = 12
hp["dropout_rate"] = 0.5 #from 0.1


In [None]:

def create_dir(path):
    """Create directory ``path`` (and any missing parents) unless something already exists there."""
    if os.path.exists(path):
        return
    os.makedirs(path)

def load_data(path, split=0.3, extensions=(".jpg",)):
    """Collect image paths under ``path/<class_name>/`` and split them into train/valid/test.

    Args:
        path: Dataset root; each immediate sub-directory holds one class.
        split: Fraction of all images held out from training (at least one
            image). The hold-out is divided again by ``train_test_split``
            with ``test_size=0.33`` into validation (~67%) and test (~33%).
        extensions: Accepted file extensions, compared case-insensitively.

    Returns:
        ``(train_x, valid_x, test_x)`` lists of image paths.

    Raises:
        ValueError: if no file with an accepted extension exists under ``path``.
    """
    allowed = {ext.lower() for ext in extensions}
    # The former pattern `"*.JPG" or ".*jpg"` always evaluated to "*.JPG", which
    # only matched lowercase ".jpg" on case-insensitive filesystems (Windows).
    # Filtering on the lower-cased extension behaves the same on every OS.
    candidates = glob(os.path.join(path, "*", "*"))
    # Sort first: glob order is filesystem-dependent, so the seeded shuffle
    # below is only reproducible across machines on a sorted list.
    images = sorted(p for p in candidates if os.path.splitext(p)[1].lower() in allowed)
    if not images:
        raise ValueError(f"No images with extensions {sorted(allowed)} found under {path!r}")
    images = shuffle(images)

    print("Total number of images:", len(images))
    print("Sample images:", images[:5])  # Print first 5 image paths

    # Ensure there's at least one image held out for validation/testing
    min_test_size = 1
    split_size = max(min_test_size, int(len(images) * split))
    print("Split size:", split_size)

    train_x, valid_test_x = train_test_split(images, test_size=split_size, random_state=42)
    valid_x, test_x = train_test_split(valid_test_x, test_size=0.33, random_state=42)

    return train_x, valid_x, test_x


def process_image_label(path):
    """Load one image, cut it into flattened patches and look up its class index.

    Invoked through ``tf.numpy_function``, so ``path`` normally arrives as
    bytes; a ``str`` is accepted too.

    Returns:
        ``(patches, class_idx)``: a float32 array reshaped to
        ``hp["flat_patches_shape"]`` (pixel values divided by 255), and an
        int32 scalar index into ``hp["class_names"]``.

    Raises:
        ValueError: if OpenCV cannot read the file, or if the image's parent
            directory name is not in ``hp["class_names"]``.
    """
    if isinstance(path, bytes):
        path = path.decode()

    # Reading images
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError(f"Could not read image: {path}")
    # NOTE: OpenCV returns channels in BGR order; kept as-is so models trained
    # on this pipeline see the same channel order.
    image = cv2.resize(image, (hp["image_size"], hp["image_size"]))
    image = image / 255.0

    # Preprocessing to patches: non-overlapping tiles (step == patch size)
    patch_shape = (hp["patch_size"], hp["patch_size"], hp["num_channels"])
    patches = patchify(image, patch_shape, hp["patch_size"])
    patches = np.reshape(patches, hp["flat_patches_shape"]).astype(np.float32)

    # Label: the image's parent directory name. os.path uses the platform's
    # separator(s), unlike the former hard-coded split on "\\" (Windows only).
    class_name = os.path.basename(os.path.dirname(path))
    try:
        class_idx = hp["class_names"].index(class_name)
    except ValueError:
        raise ValueError(f"Unknown class directory {class_name!r} for image {path}") from None

    return patches, np.array(class_idx, dtype=np.int32)

def parse(path):
    """Map one file-path tensor to ``(patches, one_hot_label)`` for ``tf.data``.

    Outputs of ``tf.numpy_function`` carry no static shape, so the known
    shapes are re-attached with ``set_shape``.
    """
    patches, class_idx = tf.numpy_function(
        process_image_label, [path], [tf.float32, tf.int32]
    )
    patches.set_shape(hp["flat_patches_shape"])

    one_hot = tf.one_hot(class_idx, hp["num_classes"])
    one_hot.set_shape([hp["num_classes"]])

    return patches, one_hot

def tf_dataset(images, batch=32, shuffle_buffer=0, num_parallel_calls=tf.data.AUTOTUNE):
    """Build a batched ``tf.data`` pipeline of ``(patches, one_hot_label)`` pairs.

    Args:
        images: Sequence of image file paths.
        batch: Batch size.
        shuffle_buffer: If > 0, reshuffle the file paths every epoch using a
            buffer of this size (intended for training). The default 0 keeps
            the order of ``images``.
        num_parallel_calls: Parallelism for ``parse``. ``Dataset.map`` keeps
            the output order deterministic by default even when parallel.
    """
    ds = tf.data.Dataset.from_tensor_slices(images)
    if shuffle_buffer > 0:
        ds = ds.shuffle(shuffle_buffer, reshuffle_each_iteration=True)
    # Decoding/patchifying is the bottleneck; run it on several threads.
    ds = ds.map(parse, num_parallel_calls=num_parallel_calls)
    return ds.batch(batch).prefetch(8)


if __name__ == "__main__":
    # Seeding (the sklearn shuffle in load_data draws from numpy's global RNG)
    np.random.seed(42)
    tf.random.set_seed(42)

    # Directory for storing files
    create_dir("files")

    # Paths: the dataset root can be overridden via PDMD_DATASET_PATH without editing the notebook.
    dataset_path = os.environ.get("PDMD_DATASET_PATH", r"C:\Users\ankit\PDMD")
    model_path = os.path.join("files", "model_pdmd_final.h5")
    csv_path = os.path.join("files", "log.csv")

    # Dataset
    train_x, valid_x, test_x = load_data(dataset_path)
    print(f"Train: {len(train_x)} - Valid: {len(valid_x)} - Test: {len(test_x)}")

    train_ds = tf_dataset(train_x, batch=hp["batch_size"])
    valid_ds = tf_dataset(valid_x, batch=hp["batch_size"])

    # Model
    model = ViT(hp)
    model.compile(
        loss="categorical_crossentropy",
        optimizer=tf.keras.optimizers.Adam(hp["lr"], clipvalue=1.0),
        metrics=["acc"]
    )

    callbacks = [
        ModelCheckpoint(model_path, monitor='val_loss', verbose=1, save_best_only=True),
        ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, min_lr=1e-10, verbose=1),
        CSVLogger(csv_path),
        # NOTE(review): patience (50) exceeds num_epochs (45), so this never stops training early.
        EarlyStopping(monitor='val_loss', patience=50, restore_best_weights=False),
    ]

    hist = model.fit(
        train_ds,
        epochs=hp["num_epochs"],
        validation_data=valid_ds,
        callbacks=callbacks
    )

    # After fit, `model` holds the final-epoch weights. Reload the best
    # (lowest val_loss) checkpoint so the evaluation cells measure that model.
    if os.path.exists(model_path):
        model.load_weights(model_path)

Total number of images: 79083
Sample images: ['C:\\Users\\ankit\\PDMD\\Pomegranate__diseased\\Pomegranate__diseased_111.JPG', 'C:\\Users\\ankit\\PDMD\\Peach__bacterial_spot\\Peach__bacterial_spot_1355.JPG', 'C:\\Users\\ankit\\PDMD\\Soybean__healthy\\Soybean__healthy_4686.jpg', 'C:\\Users\\ankit\\PDMD\\Tomato__late_blight\\Tomato__late_blight_31.JPG', 'C:\\Users\\ankit\\PDMD\\Grape__leaf_blight_(isariopsis_leaf_spot)\\Grape__leaf_blight_(isariopsis_leaf_spot)_343.JPG']
Split size: 23724
Train: 55359 - Valid: 15895 - Test: 7829
Epoch 1/45
Epoch 1: val_loss improved from inf to 1.07465, saving model to files\model_pdmd_final.h5
Epoch 2/45
Epoch 2: val_loss improved from 1.07465 to 0.77875, saving model to files\model_pdmd_final.h5
Epoch 3/45
Epoch 3: val_loss improved from 0.77875 to 0.65184, saving model to files\model_pdmd_final.h5
Epoch 4/45
Epoch 4: val_loss improved from 0.65184 to 0.61453, saving model to files\model_pdmd_final.h5
Epoch 5/45
Epoch 5: val_loss improved from 0.61453 t

In [None]:
import matplotlib.pyplot as plt

# Training vs. validation loss per epoch, from the History returned by model.fit.
fig, ax = plt.subplots()
ax.plot(hist.history['loss'], color='blue', label='loss')
ax.plot(hist.history['val_loss'], color='orange', label='val_loss')
fig.suptitle('Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Categorical cross-entropy')
ax.legend(loc="upper left")
plt.show()

ModuleNotFoundError: No module named 'matplotlib'

In [None]:
# Imported here too so this cell does not depend on the previous plotting cell.
import matplotlib.pyplot as plt

# Training vs. validation accuracy per epoch, from the History returned by model.fit.
fig, ax = plt.subplots()
ax.plot(hist.history['acc'], color='blue', label='accuracy')
ax.plot(hist.history['val_acc'], color='orange', label='val_accuracy')
fig.suptitle('Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.legend(loc="upper left")
plt.show()

In [None]:
import numpy as np
from sklearn.metrics import confusion_matrix

# Requires the training cell to have run in this kernel (tf_dataset, hp, test_x, model).
test_ds = tf_dataset(test_x, batch=hp["batch_size"])

# Labels and predictions are taken from the same batch, so their order always lines up.
y_true = []
y_pred = []

for images, labels in test_ds:
    y_true.extend(np.argmax(labels.numpy(), axis=1))
    # predict_on_batch does one inference-mode forward pass per batch, avoiding
    # the per-call setup and progress bar that model.predict adds each time.
    y_pred.extend(np.argmax(model.predict_on_batch(images), axis=1))

# Rows = true class, columns = predicted class. Fixing `labels` keeps the
# matrix num_classes x num_classes even if a class is absent from the test split.
cm = confusion_matrix(y_true, y_pred, labels=np.arange(hp["num_classes"]))
print("Confusion Matrix:")
print(cm)


NameError: name 'tf_dataset' is not defined

In [None]:
from sklearn.metrics import accuracy_score

# Fraction of test images whose predicted class equals the true class.
test_accuracy = accuracy_score(y_true=y_true, y_pred=y_pred)
print("Overall Test Accuracy:", test_accuracy)


In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score

# Macro averages: every class counts equally, regardless of how many test images it has.
precision, recall, f1 = (
    metric(y_true, y_pred, average='macro')
    for metric in (precision_score, recall_score, f1_score)
)

print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)


In [None]:
import numpy as np
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

# Class-wise metrics computed in one pass from a full confusion matrix, instead
# of rebuilding one-vs-rest label lists for every class.
class_labels = np.arange(len(hp["class_names"]))
per_class_cm = confusion_matrix(y_true, y_pred, labels=class_labels)
n_samples = per_class_cm.sum()
if n_samples == 0:
    raise ValueError("No test predictions collected; run the confusion-matrix cell first.")

tp = np.diag(per_class_cm)
fp = per_class_cm.sum(axis=0) - tp
fn = per_class_cm.sum(axis=1) - tp
tn = n_samples - tp - fp - fn

# One-vs-rest accuracy, (TP + TN) / N. With many classes TN dominates, so this
# is close to 1 for every class; recall is the more informative per-class number.
class_accuracy = ((tp + tn) / n_samples).tolist()

# zero_division=0 gives the same value sklearn's default uses (0) for classes
# with no predicted/true samples, without emitting UndefinedMetricWarning.
precision_arr, recall_arr, f1_arr, _ = precision_recall_fscore_support(
    y_true, y_pred, labels=class_labels, average=None, zero_division=0
)
class_precision = precision_arr.tolist()
class_recall = recall_arr.tolist()
class_f1 = f1_arr.tolist()

for i, class_name in enumerate(hp["class_names"]):
    print(f"Class: {class_name}")
    print(f"Accuracy: {class_accuracy[i]}")
    print(f"Precision: {class_precision[i]}")
    print(f"Recall: {class_recall[i]}")
    print(f"F1 Score: {class_f1[i]}")
