In [1]:
import os
import glob
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, Input
from tensorflow.keras.callbacks import EarlyStopping, Callback
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from sklearn.utils import class_weight
import matplotlib.pyplot as plt
import tensorflow_hub as hub

In [9]:
# === CONFIG ===
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

BASE_PATH = r"C:\Users\ADITYA DAS\Desktop\Machine Learning\CP_DATASET"
CLASSES = ["BLIGHT", "BLAST", "BROWNSPOT", "HEALTHY"]
IMG_SIZE = (224, 224) # Common size for ViT-B/16
BATCH_SIZE = 4
EPOCHS = 2

In [10]:
# === Load filepaths & labels ===
all_filepaths, all_labels = [], []
for idx, class_name in enumerate(CLASSES):
    aug_path = os.path.join(BASE_PATH, class_name, "augmented")
    files = glob.glob(os.path.join(aug_path, "*.jpg")) + \
            glob.glob(os.path.join(aug_path, "*.jpeg")) + \
            glob.glob(os.path.join(aug_path, "*.png"))
    all_filepaths.extend(files)
    all_labels.extend([idx] * len(files))

all_filepaths = np.array(all_filepaths)
all_labels = np.array(all_labels)

In [11]:
# === Create datasets ===
def load_and_preprocess_image(filepath, label):
    img = tf.io.read_file(filepath)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, IMG_SIZE)
    img = img / 255.0  # Normalize to [0, 1] for ViT
    return img, tf.one_hot(label, len(CLASSES))

dataset = tf.data.Dataset.from_tensor_slices((all_filepaths, all_labels))
dataset = dataset.shuffle(buffer_size=1024, seed=SEED)

train_size = int(0.8 * len(all_filepaths))
val_size = int(0.1 * len(all_filepaths))
test_size = len(all_filepaths) - train_size - val_size

train_ds = dataset.take(train_size).map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
val_ds = dataset.skip(train_size).take(val_size).map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_ds = dataset.skip(train_size + val_size).map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

print(f"Dataset sizes: Train={len(train_ds)*BATCH_SIZE}, Val={len(val_ds)*BATCH_SIZE}, Test={len(test_ds)*BATCH_SIZE}")

Dataset sizes: Train=19204, Val=2400, Test=2404


In [7]:
# === Build Model (Phase 1: Feature Extractor Frozen) ===
# Use the new working ViT model from TensorFlow Hub
vit_model_handle = "https://www.kaggle.com/models/spsayakpaul/vision-transformer/TensorFlow2/vit-b16-classification/1" # Corrected ViT feature extractor URL
# model = tf.keras.Sequential([
#     hub.KerasLayer("https://www.kaggle.com/models/spsayakpaul/vision-transformer/TensorFlow2/vit-b16-classification/1")
# ])

# Input layer for the images
inputs = Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 3))

# Load the ViT feature extractor
# trainable=False freezes the weights of the ViT backbone
vit_backbone = hub.KerasLayer(vit_model_handle, trainable=False, name='vit_feature_extractor')(inputs)

# Add a classification head on top of the frozen ViT features
x = Dropout(0.5)(vit_backbone)
outputs = Dense(len(CLASSES), activation='softmax')(x)

model = Model(inputs=inputs, outputs=outputs)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), # Higher LR for new layers
    loss='categorical_crossentropy',
    metrics=['accuracy', tf.keras.metrics.TopKCategoricalAccuracy(k=3, name='top3_accuracy')]
)

model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 vit_feature_extractor (Kera  (None, 1000)             86567656  
 sLayer)                                                         
                                                                 
 dropout (Dropout)           (None, 1000)              0         
                                                                 
 dense (Dense)               (None, 4)                 4004      
                                                                 
Total params: 86,571,660
Trainable params: 4,004
Non-trainable params: 86,567,656
_________________________________________________________________


In [12]:
# === Learning Rate Logger Callback ===
class LearningRateLogger(Callback):
    def on_epoch_end(self, epoch, logs=None):
        lr = self.model.optimizer.lr
        if hasattr(lr, '__call__'):
            lr = lr(self.model.optimizer.iterations)
        if hasattr(lr, 'numpy'):
            lr = lr.numpy()
        print(f"📉 Learning rate at epoch {epoch+1}: {lr:.6f}")

# === Compute class weights ===
y_train_int = np.argmax(np.concatenate([labels.numpy() for _, labels in train_ds.unbatch().batch(BATCH_SIZE)]), axis=1)
class_weights = dict(enumerate(class_weight.compute_class_weight(
    class_weight='balanced',
    classes=np.arange(len(CLASSES)),
    y=y_train_int
)))
print("✅ Computed class weights:", class_weights)

✅ Computed class weights: {0: 1.3298476454293628, 1: 0.7029945819300044, 2: 0.8826530612244898, 3: 1.4438345864661655}


In [13]:
# === Train ===
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=[EarlyStopping(patience=4, restore_best_weights=True), LearningRateLogger()],
    class_weight=class_weights
)

Epoch 1/2


InternalError: Graph execution error:

Detected at node 'model/vit_feature_extractor/StatefulPartitionedCall' defined at (most recent call last):
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\ipykernel_launcher.py", line 18, in <module>
      app.launch_new_instance()
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\traitlets\config\application.py", line 1075, in launch_instance
      app.start()
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\ipykernel\kernelapp.py", line 739, in start
      self.io_loop.start()
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\tornado\platform\asyncio.py", line 211, in start
      self.asyncio_loop.run_forever()
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\asyncio\base_events.py", line 601, in run_forever
      self._run_once()
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\asyncio\base_events.py", line 1905, in _run_once
      handle._run()
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\asyncio\events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\ipykernel\kernelbase.py", line 545, in dispatch_queue
      await self.process_one()
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\ipykernel\kernelbase.py", line 534, in process_one
      await dispatch(*args)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\ipykernel\kernelbase.py", line 437, in dispatch_shell
      await result
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\ipykernel\ipkernel.py", line 362, in execute_request
      await super().execute_request(stream, ident, parent)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\ipykernel\kernelbase.py", line 778, in execute_request
      reply_content = await reply_content
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\ipykernel\ipkernel.py", line 449, in do_execute
      res = shell.run_cell(
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\ipykernel\zmqshell.py", line 549, in run_cell
      return super().run_cell(*args, **kwargs)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\IPython\core\interactiveshell.py", line 3024, in run_cell
      result = self._run_cell(
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\IPython\core\interactiveshell.py", line 3079, in _run_cell
      result = runner(coro)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\IPython\core\interactiveshell.py", line 3284, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\IPython\core\interactiveshell.py", line 3466, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\IPython\core\interactiveshell.py", line 3526, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "C:\Users\ADITYA DAS\AppData\Local\Temp\ipykernel_12996\3391820251.py", line 2, in <module>
      history = model.fit(
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\keras\engine\training.py", line 1564, in fit
      tmp_logs = self.train_function(iterator)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\keras\engine\training.py", line 1160, in train_function
      return step_function(self, iterator)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\keras\engine\training.py", line 1146, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\keras\engine\training.py", line 1135, in run_step
      outputs = model.train_step(data)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\keras\engine\training.py", line 993, in train_step
      y_pred = self(x, training=True)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\keras\engine\training.py", line 557, in __call__
      return super().__call__(*args, **kwargs)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\keras\engine\functional.py", line 510, in call
      return self._run_internal_graph(inputs, training=training, mask=mask)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\keras\engine\functional.py", line 667, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\tensorflow_hub\keras_layer.py", line 233, in call
      if not self._has_training_argument:
    File "C:\Users\ADITYA DAS\.conda\envs\tf2.10.1\lib\site-packages\tensorflow_hub\keras_layer.py", line 234, in call
      result = f()
Node: 'model/vit_feature_extractor/StatefulPartitionedCall'
libdevice not found at ./libdevice.10.bc
	 [[{{node model/vit_feature_extractor/StatefulPartitionedCall}}]] [Op:__inference_train_function_41131]

In [1]:
# === Evaluate ===
y_true, y_pred = [], []
for images, labels in val_ds:
    preds = model.predict(images)
    y_pred.extend(np.argmax(preds, axis=1))
    y_true.extend(np.argmax(labels.numpy(), axis=1))

print("\n📊 Classification Report:")
print(classification_report(y_true, y_pred, target_names=CLASSES))

# Confusion Matrix
cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=CLASSES)
disp.plot(cmap=plt.cm.Blues)
plt.title('Confusion Matrix - Phase 1 (ViT)')
plt.show()

# === Save ===
SAVE_PATH = r"C:\Users\ADITYA DAS\Desktop\Machine Learning\CP_MODEL\ViT_Phase1_CutMix_GridMask.h5"
model.save(SAVE_PATH)
print(f"✅ Model saved at: {SAVE_PATH}")