In [1]:
import os
import shutil

SOURCE_BASE = "/kaggle/input/plant-village-dataset-updated"

DEST_BASE = "/kaggle/working/plant-village-merged"

plants = [
    "Apple",
    "Bell Pepper",
    "Cherry",
    "Corn (Maize)",
    "Grape",
    "Peach",
    "Potato",
    "Strawberry",
    "Tomato"
]

splits = ["Train", "Val", "Test"]


for split in splits:
    os.makedirs(os.path.join(DEST_BASE, split), exist_ok=True)


for plant in plants:
    for split in splits:
        plant_split_dir = os.path.join(SOURCE_BASE, plant, split)
        if not os.path.isdir(plant_split_dir):
            continue
        
        # Each disease folder inside Apple/Train, Apple/Val, etc.
        for disease in os.listdir(plant_split_dir):
            disease_folder = os.path.join(plant_split_dir, disease)
            if not os.path.isdir(disease_folder):
                continue
            
            # Make a combined class name to avoid collisions:
            # e.g., "Apple_Apple Scab"
            combined_class_name = f"{plant}_{disease}"
            
            # Create the destination subfolder
            dest_class_dir = os.path.join(DEST_BASE, split, combined_class_name)
            os.makedirs(dest_class_dir, exist_ok=True)
            
            # Copy all images
            for img_name in os.listdir(disease_folder):
                src_path = os.path.join(disease_folder, img_name)

                if os.path.isdir(src_path):
                    continue
                
                dst_path = os.path.join(dest_class_dir, img_name)
                # Copy or move the image (copy2 preserves metadata)
                shutil.copy2(src_path, dst_path)

print("Merging complete!")


import os

def count_images_in_folder(folder_path):
    """Count the number of image files in a folder (recursively)."""
    valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')
    count = 0
    for root, dirs, files in os.walk(folder_path):
        for filename in files:
            if filename.lower().endswith(valid_extensions):
                count += 1
    return count

# Paths to your flattened dataset splits
train_path = '/kaggle/working/plant-village-merged/Train'
val_path   = '/kaggle/working/plant-village-merged/Val'
test_path  = '/kaggle/working/plant-village-merged/Test'

# Count images in each split
train_count = count_images_in_folder(train_path)
val_count   = count_images_in_folder(val_path)
test_count  = count_images_in_folder(test_path)

print("Total training images:  ", train_count)
print("Total validation images:", val_count)
print("Total testing images:   ", test_count)

Merging complete!
Total training images:   53690
Total validation images: 12067
Total testing images:    1354


In [2]:
import os
import numpy as np
import tensorflow as tf
from collections import Counter

# -----------------------------
# DATASET SETUP
# -----------------------------
dataset_path = '/kaggle/working/plant-village-merged'
train_path = os.path.join(dataset_path, "Train")
val_path   = os.path.join(dataset_path, "Val")
test_path  = os.path.join(dataset_path, "Test")

# Function to get file paths and labels from the directory structure.
def get_file_paths_and_labels(train_path):
    classes = sorted(os.listdir(train_path))
    file_paths = []
    labels = []
    for idx, cls in enumerate(classes):
        class_dir = os.path.join(train_path, cls)
        if os.path.isdir(class_dir):
            for file in os.listdir(class_dir):
                if os.path.isfile(os.path.join(class_dir, file)):
                    file_paths.append(os.path.join(class_dir, file))
                    labels.append(idx)
    return file_paths, labels, classes

# Get the file paths and labels
file_paths, labels, classes = get_file_paths_and_labels(train_path)

# Check counts per class (using numeric labels)
counts = Counter(labels)
print("Original class counts:")
for cls_idx, count in counts.items():
    print(f"{classes[cls_idx]}: {count} images")

# Determine the maximum number of images among all classes
max_count = max(counts.values())

# Oversample: for each class, randomly choose additional indices to match the max count.
balanced_file_paths = list(file_paths)
balanced_labels = list(labels)

for cls in range(len(classes)):
    # Get indices for the current class
    cls_indices = [i for i, label in enumerate(labels) if label == cls]
    current_count = counts[cls]
    
    # Number of samples to add
    n_to_add = max_count - current_count
    if n_to_add > 0:
        # Randomly sample indices with replacement from current indices
        oversampled_indices = np.random.choice(cls_indices, size=n_to_add, replace=True)
        balanced_file_paths.extend([file_paths[i] for i in oversampled_indices])
        balanced_labels.extend([labels[i] for i in oversampled_indices])

# Verify new counts
balanced_counts = Counter(balanced_labels)
print("\nBalanced class counts:")
for cls_idx, count in balanced_counts.items():
    print(f"{classes[cls_idx]}: {count} images")

# Create a TensorFlow dataset from the balanced file paths and labels.
def load_and_preprocess_image(path, label, img_size=(128, 128)):
    # Read image file
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, img_size)
    # Normalize the image to [0,1]
    image = image / 255.0
    return image, label

balanced_ds = tf.data.Dataset.from_tensor_slices((balanced_file_paths, balanced_labels))
balanced_ds = balanced_ds.shuffle(buffer_size=len(balanced_file_paths))
balanced_ds = balanced_ds.map(lambda path, label: load_and_preprocess_image(path, label),
                              num_parallel_calls=tf.data.AUTOTUNE)
balanced_ds = balanced_ds.batch(32).prefetch(tf.data.AUTOTUNE)


# Path where the new balanced dataset will be saved
balanced_dataset_path = "/kaggle/working/balanced-plant-village"

# Create the new dataset directory
os.makedirs(balanced_dataset_path, exist_ok=True)

for cls in classes:
    os.makedirs(os.path.join(balanced_dataset_path, cls), exist_ok=True)

# Copy original images to new dataset
for src_path, label in zip(file_paths, labels):
    class_name = classes[label]
    dst_path = os.path.join(balanced_dataset_path, class_name, os.path.basename(src_path))
    shutil.copy(src_path, dst_path)

# Oversampling: copy additional images to balance dataset
for cls in range(len(classes)):
    cls_indices = [i for i, label in enumerate(labels) if label == cls]
    current_count = counts[cls]
    n_to_add = max_count - current_count
    
    if n_to_add > 0:
        oversampled_indices = np.random.choice(cls_indices, size=n_to_add, replace=True)
        for idx in oversampled_indices:
            src_path = file_paths[idx]
            class_name = classes[cls]
            filename = f"aug_{np.random.randint(10000)}_{os.path.basename(src_path)}"  # Rename to avoid conflicts
            dst_path = os.path.join(balanced_dataset_path, class_name, filename)
            shutil.copy(src_path, dst_path)

print(f"Balanced dataset saved at: {balanced_dataset_path}")

2025-04-22 15:00:15.170518: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745334015.426349      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745334015.496737      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Original class counts:
Apple_Apple Scab: 2016 images
Apple_Black Rot: 1987 images
Apple_Cedar Apple Rust: 1760 images
Apple_Healthy: 2008 images
Bell Pepper_Bacterial Spot: 1913 images
Bell Pepper_Healthy: 1988 images
Cherry_Healthy: 1826 images
Cherry_Powdery Mildew: 1683 images
Corn (Maize)_Cercospora Leaf Spot: 1642 images
Corn (Maize)_Common Rust : 1907 images
Corn (Maize)_Healthy: 1859 images
Corn (Maize)_Northern Leaf Blight: 1908 images
Grape_Black Rot: 1888 images
Grape_Esca (Black Measles): 1920 images
Grape_Healthy: 1692 images
Grape_Leaf Blight: 1722 images
Peach_Bacterial Spot: 1838 images
Peach_Healthy: 1728 images
Potato_Early Blight: 1939 images
Potato_Healthy: 1824 images
Potato_Late Blight: 1939 images
Strawberry_Healthy: 1824 images
Strawberry_Leaf Scorch: 1774 images
Tomato_Bacterial Spot: 1702 images
Tomato_Early Blight: 1920 images
Tomato_Healthy: 1926 images
Tomato_Late Blight: 1851 images
Tomato_Septoria Leaf Spot: 1745 images
Tomato_Yellow Leaf Curl Virus: 1961 

I0000 00:00:1745334029.414304      31 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13942 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
I0000 00:00:1745334029.414954      31 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13942 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5


Balanced dataset saved at: /kaggle/working/balanced-plant-village


In [10]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from collections import Counter
from sklearn.model_selection import StratifiedKFold
from tensorflow.keras import layers, Model, Input, regularizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# -----------------------------
# 0. SETTINGS
# -----------------------------
SEED       = 42
IMG_SIZE   = (128,128)
BATCH_SIZE = 32
EPOCHS     = 50
N_SPLITS   = 5
PROJ_DIM   = 1280
INITIAL_LR = 1e-3

# Paths (after merging & oversampling into disk)
BALANCED_DIR = "/kaggle/working/balanced-plant-village"

# -----------------------------
# 1. GATHER ALL FILE PATHS & LABELS
# -----------------------------
classes = sorted(os.listdir(BALANCED_DIR))
file_paths = []
labels     = []
for idx, cls in enumerate(classes):
    cls_dir = os.path.join(BALANCED_DIR, cls)
    for fname in os.listdir(cls_dir):
        if fname.lower().endswith(('.jpg','jpeg','png')):
            file_paths.append(os.path.join(cls_dir, fname))
            labels.append(idx)

file_paths = np.array(file_paths)
labels     = np.array(labels)
print("Total images:", len(file_paths))
print("Classes:", classes)

# -----------------------------
# 2. HELPERS TO BUILD MODELS
# -----------------------------
def build_teacher_model(base_fn):
    base_model = base_fn(input_shape=(*IMG_SIZE,3), include_top=False, pooling='avg', weights='imagenet')
    inp = Input(shape=(*IMG_SIZE,3))
    base = base_model(inp)
    proj = layers.Dense(PROJ_DIM, activation='relu', name='proj_features')(base)
    pred = layers.Dense(len(classes), activation='softmax')(proj)
    return Model(inp, [proj, pred])

def squeeze_excite_block(x, ratio=16):
    f = x.shape[-1]
    se = layers.GlobalAveragePooling2D()(x)
    se = layers.Reshape((1,1,f))(se)
    se = layers.Dense(f//ratio, activation='relu')(se)
    se = layers.Dense(f, activation='sigmoid')(se)
    return layers.multiply([x,se])

def build_student_model():
    inp = Input((*IMG_SIZE,3))
    x = layers.Conv2D(32,3,padding='same',activation='relu',
                      kernel_regularizer=regularizers.l2(1e-4))(inp)
    x = layers.BatchNormalization()(x); x = layers.MaxPool2D()(x)

    x = layers.SeparableConv2D(64,3,padding='same',activation='relu',
        depthwise_regularizer=regularizers.l2(1e-4),pointwise_regularizer=regularizers.l2(1e-4))(x)
    x = layers.BatchNormalization()(x); x = squeeze_excite_block(x); x = layers.MaxPool2D()(x)

    x = layers.Conv2D(128,3,padding='same',activation='relu',
                      kernel_regularizer=regularizers.l2(1e-4))(x)
    x = layers.BatchNormalization()(x); x = layers.MaxPool2D()(x); x = layers.Dropout(0.3)(x)

    x = layers.Conv2D(256,3,padding='same',activation='relu',
                      kernel_regularizer=regularizers.l2(1e-4))(x)
    x = layers.BatchNormalization()(x)
    feat = layers.GlobalAveragePooling2D()(x)
    pred = layers.Dense(len(classes), activation='softmax')(feat)

    return Model(inp, [feat, pred])
    
class MultiLevelDistiller(tf.keras.Model):
    def __init__(self, student, teachers, temp=5., alpha=0.5, beta=0.5):
        super().__init__()
        self.student = student
        self.teachers = teachers
        self.temp, self.alpha, self.beta = temp, alpha, beta
        self.ce  = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)
        self.kld = tf.keras.losses.KLDivergence()
        self.mse = tf.keras.losses.MeanSquaredError()
        self.acc = tf.keras.metrics.CategoricalAccuracy()
        
        self.w = self.add_weight(name="tw", shape=(len(teachers),),
                                 initializer="ones", trainable=True)
        
        self.proj = layers.Dense(PROJ_DIM, activation='relu', name='student_proj')

    def compile(self, opt):
        super().compile()
        self.opt = opt

    def train_step(self, data):
        x, y = data
        t_feats, t_preds = [], []
        for t in self.teachers:
            f, p = t(x, training=False)
            t_feats.append(f)
            t_preds.append(p)

        weights = tf.nn.softmax(self.w)
        t_preds_stack = tf.stack(t_preds)
        t_feats_stack = tf.stack(t_feats)
        weights_reshaped = tf.reshape(weights, (-1, 1, 1))

        ft = tf.reduce_sum(weights_reshaped * t_preds_stack, axis=0)
        ff = tf.reduce_sum(weights_reshaped * t_feats_stack, axis=0)

        with tf.GradientTape() as tape:
            s_feat, s_pred = self.student(x, training=True)
            s_proj = self.proj(s_feat)

            loss_ce = self.ce(y, s_pred)
            loss_kd = self.kld(tf.nn.softmax(ft / self.temp), tf.nn.softmax(s_pred / self.temp)) * self.temp**2
            loss_feat = self.mse(ff, s_proj)

            loss = self.alpha * loss_ce + (1 - self.alpha) * (loss_kd + self.beta * loss_feat)

        vars = self.student.trainable_variables + self.proj.trainable_variables + [self.w]
        grads = tape.gradient(loss, vars)
        self.opt.apply_gradients(zip(grads, vars))

        self.acc.update_state(y, s_pred)
        return {"loss": loss, "accuracy": self.acc.result()}

    def test_step(self, data):
        x, y = data
        _, s_pred = self.student(x, training=False)
        loss = self.ce(y, s_pred)
        self.acc.update_state(y, s_pred)
        return {"loss": loss, "accuracy": self.acc.result()}


# -----------------------------
# 3. K‑FOLD CROSS‑VALIDATION
# -----------------------------
kf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
cv_acc = []

for fold, (train_idx, val_idx) in enumerate(kf.split(file_paths, labels), 1):
    print(f"\n=== Fold {fold}/{N_SPLITS} ===")

    x_tr, y_tr = file_paths[train_idx], labels[train_idx]
    x_vl, y_vl = file_paths[val_idx], labels[val_idx]

    y_tr_cat = tf.keras.utils.to_categorical(y_tr, num_classes=len(classes))
    y_vl_cat = tf.keras.utils.to_categorical(y_vl, num_classes=len(classes))

    y_tr_str = [classes[i] for i in y_tr]
    y_vl_str = [classes[i] for i in y_vl]

    train_df = pd.DataFrame({"filename": x_tr, "class": y_tr_str})
    val_df = pd.DataFrame({"filename": x_vl, "class": y_vl_str})

    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=30,
        horizontal_flip=True,
        zoom_range=0.3,
        width_shift_range=0.3,
        height_shift_range=0.3,
        brightness_range=[0.7,1.3],
        shear_range=0.2,
        fill_mode='nearest'
    )
    val_datagen = ImageDataGenerator(rescale=1./255)

    train_gen = train_datagen.flow_from_dataframe(
        train_df, x_col="filename", y_col="class",
        target_size=IMG_SIZE, batch_size=BATCH_SIZE,
        class_mode="categorical", shuffle=True
    )
    val_gen = val_datagen.flow_from_dataframe(
        val_df, x_col="filename", y_col="class",
        target_size=IMG_SIZE, batch_size=BATCH_SIZE,
        class_mode="categorical", shuffle=False
    )

    teachers = [
        build_teacher_model(tf.keras.applications.MobileNetV3Large),
        build_teacher_model(tf.keras.applications.EfficientNetB1),
        build_teacher_model(tf.keras.applications.MobileNetV2),
    ]
    for t in teachers: t.trainable = False
    student = build_student_model()

    distiller = MultiLevelDistiller(student, teachers)
    lr_sched = tf.keras.optimizers.schedules.CosineDecay(INITIAL_LR, decay_steps=10000, alpha=1e-6)
    opt = tf.keras.optimizers.AdamW(learning_rate=lr_sched, weight_decay=1e-4)
    distiller.compile(opt)

    callbacks = [
        tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=3, min_lr=1e-6, verbose=1),
        tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True, verbose=1),
    ]

    distiller.fit(train_gen, validation_data=val_gen, epochs=EPOCHS, callbacks=callbacks, verbose=2)

    res = distiller.evaluate(val_gen, verbose=0)
    print(f"Fold {fold} Val accuracy: {res[1]:.4f}")
    cv_acc.append(res[1])

# -----------------------------
# 4. FINAL RESULTS
# -----------------------------
print(f"\nAverage CV Accuracy over {N_SPLITS} folds: {np.mean(cv_acc):.4f}")


Total images: 58464
Classes: ['Apple_Apple Scab', 'Apple_Black Rot', 'Apple_Cedar Apple Rust', 'Apple_Healthy', 'Bell Pepper_Bacterial Spot', 'Bell Pepper_Healthy', 'Cherry_Healthy', 'Cherry_Powdery Mildew', 'Corn (Maize)_Cercospora Leaf Spot', 'Corn (Maize)_Common Rust ', 'Corn (Maize)_Healthy', 'Corn (Maize)_Northern Leaf Blight', 'Grape_Black Rot', 'Grape_Esca (Black Measles)', 'Grape_Healthy', 'Grape_Leaf Blight', 'Peach_Bacterial Spot', 'Peach_Healthy', 'Potato_Early Blight', 'Potato_Healthy', 'Potato_Late Blight', 'Strawberry_Healthy', 'Strawberry_Leaf Scorch', 'Tomato_Bacterial Spot', 'Tomato_Early Blight', 'Tomato_Healthy', 'Tomato_Late Blight', 'Tomato_Septoria Leaf Spot', 'Tomato_Yellow Leaf Curl Virus']

=== Fold 1/5 ===
Found 46771 validated image filenames belonging to 29 classes.
Found 11693 validated image filenames belonging to 29 classes.


  return MobileNetV3(


Epoch 1/50


  self._warn_if_super_not_called()


1462/1462 - 286s - 196ms/step - accuracy: 0.7588 - loss: 0.5229 - val_accuracy: 0.7834 - val_loss: 1.3796 - learning_rate: 0.0010
Epoch 2/50
1462/1462 - 240s - 164ms/step - accuracy: 0.9257 - loss: 0.5189 - val_accuracy: 0.8627 - val_loss: 1.0158 - learning_rate: 0.0010
Epoch 3/50
1462/1462 - 235s - 161ms/step - accuracy: 0.9531 - loss: 0.4547 - val_accuracy: 0.8915 - val_loss: 0.8957 - learning_rate: 0.0010
Epoch 4/50
1462/1462 - 236s - 162ms/step - accuracy: 0.9674 - loss: 0.4365 - val_accuracy: 0.9156 - val_loss: 0.7503 - learning_rate: 0.0010
Epoch 5/50
1462/1462 - 237s - 162ms/step - accuracy: 0.9751 - loss: 0.4117 - val_accuracy: 0.9406 - val_loss: 0.7226 - learning_rate: 0.0010
Epoch 6/50
1462/1462 - 236s - 161ms/step - accuracy: 0.9804 - loss: 0.4379 - val_accuracy: 0.9737 - val_loss: 0.7220 - learning_rate: 0.0010
Epoch 7/50
1462/1462 - 235s - 161ms/step - accuracy: 0.9833 - loss: 0.4144 - val_accuracy: 0.9726 - val_loss: 0.7348 - learning_rate: 0.0010
Epoch 8/50
1462/1462 - 2

  return MobileNetV3(


Epoch 1/50


  self._warn_if_super_not_called()


1462/1462 - 281s - 192ms/step - accuracy: 0.7286 - loss: 0.5588 - val_accuracy: 0.7499 - val_loss: 1.5337 - learning_rate: 0.0010
Epoch 2/50
1462/1462 - 237s - 162ms/step - accuracy: 0.9158 - loss: 0.4330 - val_accuracy: 0.9009 - val_loss: 1.0135 - learning_rate: 0.0010
Epoch 3/50
1462/1462 - 233s - 160ms/step - accuracy: 0.9495 - loss: 0.4621 - val_accuracy: 0.9524 - val_loss: 0.7505 - learning_rate: 0.0010
Epoch 4/50
1462/1462 - 232s - 159ms/step - accuracy: 0.9649 - loss: 0.4814 - val_accuracy: 0.9666 - val_loss: 0.7599 - learning_rate: 0.0010
Epoch 5/50
1462/1462 - 233s - 159ms/step - accuracy: 0.9764 - loss: 0.3878 - val_accuracy: 0.9757 - val_loss: 0.7442 - learning_rate: 0.0010
Epoch 6/50
1462/1462 - 232s - 159ms/step - accuracy: 0.9798 - loss: 0.3953 - val_accuracy: 0.9819 - val_loss: 0.7113 - learning_rate: 0.0010
Epoch 7/50
1462/1462 - 234s - 160ms/step - accuracy: 0.9829 - loss: 0.3793 - val_accuracy: 0.9810 - val_loss: 0.7090 - learning_rate: 0.0010
Epoch 8/50
1462/1462 - 2

  return MobileNetV3(


Epoch 1/50


  self._warn_if_super_not_called()


1462/1462 - 282s - 193ms/step - accuracy: 0.7536 - loss: 0.4994 - val_accuracy: 0.8562 - val_loss: 1.0344 - learning_rate: 0.0010
Epoch 2/50
1462/1462 - 234s - 160ms/step - accuracy: 0.9280 - loss: 0.4393 - val_accuracy: 0.8844 - val_loss: 0.9458 - learning_rate: 0.0010
Epoch 3/50
1462/1462 - 233s - 160ms/step - accuracy: 0.9544 - loss: 0.4360 - val_accuracy: 0.8693 - val_loss: 1.5642 - learning_rate: 0.0010
Epoch 4/50
1462/1462 - 235s - 161ms/step - accuracy: 0.9689 - loss: 0.4611 - val_accuracy: 0.9572 - val_loss: 0.8123 - learning_rate: 0.0010
Epoch 5/50
1462/1462 - 234s - 160ms/step - accuracy: 0.9769 - loss: 0.4229 - val_accuracy: 0.9689 - val_loss: 0.8131 - learning_rate: 0.0010
Epoch 6/50
1462/1462 - 232s - 159ms/step - accuracy: 0.9819 - loss: 0.4190 - val_accuracy: 0.9731 - val_loss: 0.7987 - learning_rate: 0.0010
Epoch 7/50
1462/1462 - 234s - 160ms/step - accuracy: 0.9841 - loss: 0.3697 - val_accuracy: 0.9812 - val_loss: 0.7794 - learning_rate: 0.0010
Epoch 8/50
1462/1462 - 2

  return MobileNetV3(


Epoch 1/50


  self._warn_if_super_not_called()


1462/1462 - 281s - 192ms/step - accuracy: 0.7598 - loss: 0.5096 - val_accuracy: 0.8061 - val_loss: 1.0012 - learning_rate: 0.0010
Epoch 2/50
1462/1462 - 232s - 159ms/step - accuracy: 0.9304 - loss: 0.4105 - val_accuracy: 0.9316 - val_loss: 0.9046 - learning_rate: 0.0010
Epoch 3/50
1462/1462 - 234s - 160ms/step - accuracy: 0.9569 - loss: 0.4365 - val_accuracy: 0.9174 - val_loss: 0.9014 - learning_rate: 0.0010
Epoch 4/50
1462/1462 - 240s - 164ms/step - accuracy: 0.9706 - loss: 0.4244 - val_accuracy: 0.9737 - val_loss: 0.8783 - learning_rate: 0.0010
Epoch 5/50
1462/1462 - 237s - 162ms/step - accuracy: 0.9772 - loss: 0.4060 - val_accuracy: 0.9677 - val_loss: 0.7503 - learning_rate: 0.0010
Epoch 6/50
1462/1462 - 241s - 165ms/step - accuracy: 0.9829 - loss: 0.4107 - val_accuracy: 0.9789 - val_loss: 0.7688 - learning_rate: 0.0010
Epoch 7/50
1462/1462 - 237s - 162ms/step - accuracy: 0.9838 - loss: 0.4441 - val_accuracy: 0.9780 - val_loss: 0.7586 - learning_rate: 0.0010
Epoch 8/50

Epoch 8: Red

  return MobileNetV3(


Epoch 1/50


  self._warn_if_super_not_called()


1462/1462 - 286s - 195ms/step - accuracy: 0.7600 - loss: 0.5421 - val_accuracy: 0.8576 - val_loss: 0.8504 - learning_rate: 0.0010
Epoch 2/50
1462/1462 - 234s - 160ms/step - accuracy: 0.9299 - loss: 0.4390 - val_accuracy: 0.9188 - val_loss: 0.7948 - learning_rate: 0.0010
Epoch 3/50
1462/1462 - 234s - 160ms/step - accuracy: 0.9554 - loss: 0.5009 - val_accuracy: 0.9589 - val_loss: 0.7898 - learning_rate: 0.0010
Epoch 4/50
1462/1462 - 232s - 159ms/step - accuracy: 0.9696 - loss: 0.4142 - val_accuracy: 0.9302 - val_loss: 0.8023 - learning_rate: 0.0010
Epoch 5/50
1462/1462 - 236s - 161ms/step - accuracy: 0.9776 - loss: 0.3897 - val_accuracy: 0.9674 - val_loss: 0.7986 - learning_rate: 0.0010
Epoch 6/50
1462/1462 - 235s - 161ms/step - accuracy: 0.9829 - loss: 0.4142 - val_accuracy: 0.9835 - val_loss: 0.7671 - learning_rate: 0.0010
Epoch 7/50
1462/1462 - 233s - 159ms/step - accuracy: 0.9850 - loss: 0.4210 - val_accuracy: 0.9828 - val_loss: 0.7686 - learning_rate: 0.0010
Epoch 8/50
1462/1462 - 2

In [11]:
import os

final_model_save_path = "/kaggle/working/final_model"
os.makedirs(final_model_save_path, exist_ok=True)

# Save the trained student model from the distiller (after final fold completion)
student_model = distiller.student

# Save the student model
student_model.save(os.path.join(final_model_save_path, "student_model.h5"))

distiller.save(os.path.join(final_model_save_path, "distiller_model.h5"))

print(f"Model saved at {final_model_save_path}")


Model saved at /kaggle/working/final_model
