# Add imports

In [2]:
import tensorflow as tf

from tensorflow.keras.preprocessing import image_dataset_from_directory, image
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import load_img, img_to_array

from matplotlib import pyplot as plt

import numpy as np


In [None]:
# Report the TensorFlow version and the GPUs visible to this kernel.
# (A `python3 -c "..."` shell command is not valid Python inside a code cell;
# query the devices directly from the already-imported `tf` module instead.)
print("TensorFlow version:", tf.__version__)
print("GPUs:", tf.config.list_physical_devices('GPU'))


TensorFlow version: 2.20.0
GPU


# 1. Load dataset

In [5]:
# Folder handed to image_dataset_from_directory() when the training and
# validation subsets are created.
data_dir = "dataset"

# Size every image is resized to on load (passed as `image_size`), and the
# number of images per batch. The same size is reused for the model input shape.
img_size = (224, 224)
batch_size = 32

## 1.1. Split dataset into training and validation

In [6]:
# Training subset: 80% of the files under data_dir (validation_split=0.2).
# The validation cell uses the same seed and split fraction — presumably this
# is what keeps the two subsets disjoint; confirm against the Keras docs.
train_ds = image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=42,
    image_size=img_size,
    batch_size=batch_size
)

Found 2552 files belonging to 6 classes.
Using 2042 files for training.


2025-08-28 12:59:58.265669: W tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.cc:40] 'cuModuleLoadData(&module, data)' failed with 'CUDA_ERROR_INVALID_PTX'

2025-08-28 12:59:58.265705: W tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.cc:40] 'cuModuleGetFunction(&function, module, kernel_name)' failed with 'CUDA_ERROR_INVALID_HANDLE'

2025-08-28 12:59:58.265715: W tensorflow/core/framework/op_kernel.cc:1842] INTERNAL: 'cuLaunchKernel(function, gridX, gridY, gridZ, blockX, blockY, blockZ, 0, reinterpret_cast<CUstream>(stream), params, nullptr)' failed with 'CUDA_ERROR_INVALID_HANDLE'
2025-08-28 12:59:58.265721: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: INTERNAL: 'cuLaunchKernel(function, gridX, gridY, gridZ, blockX, blockY, blockZ, 0, reinterpret_cast<CUstream>(stream), params, nullptr)' failed with 'CUDA_ERROR_INVALID_HANDLE'


InternalError: {{function_node __wrapped__Equal_device_/job:localhost/replica:0/task:0/device:GPU:0}} 'cuLaunchKernel(function, gridX, gridY, gridZ, blockX, blockY, blockZ, 0, reinterpret_cast<CUstream>(stream), params, nullptr)' failed with 'CUDA_ERROR_INVALID_HANDLE' [Op:Equal] name: 

In [None]:
# Validation subset: the held-out 20% of the files under data_dir.
# seed and validation_split are identical to the training call above.
val_ds = image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=42,
    image_size=img_size,
    batch_size=batch_size
)

In [None]:
# Capture the class labels now: train_ds is re-assigned to a
# cached/shuffled/prefetched pipeline in a later cell, which may no longer
# expose `.class_names` (verify).
class_names = train_ds.class_names
num_classes = len(class_names)

## 1.2. Caching, shuffling and prefetching to improve performance

In [None]:
# Let tf.data choose the prefetch buffer size.
AUTOTUNE = tf.data.AUTOTUNE

# Training pipeline: cache the loaded data, shuffle with a 500-element buffer
# (elements are presumably whole batches, since the dataset was created with a
# batch_size), then prefetch.
train_ds = (
    train_ds
    .cache()
    .shuffle(500)
    .prefetch(AUTOTUNE)
)

# Validation pipeline: cache and prefetch only, no shuffling.
val_ds = (
    val_ds
    .cache()
    .prefetch(AUTOTUNE)
)

# 2. Data Augmentation

In [None]:
# Augmentation stack applied to the model input before preprocessing:
# random horizontal flip, random rotation (factor 0.1) and random zoom
# (factor 0.1). NOTE(review): these Keras layers are expected to be active
# only in training mode — confirm against the Keras docs.
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
])

# 3. Load Pre-trained MobileNetV2

In [None]:
# ImageNet-pretrained MobileNetV2 without its classification head
# (include_top=False), built for img_size RGB inputs.
base_model = MobileNetV2(
    input_shape=(*img_size, 3),
    weights="imagenet",
    include_top=False,
)

# Freeze the backbone for the first training stage; only the layers added
# on top of it will be trained.
base_model.trainable = False

# 4. Build Model

In [None]:
# Assemble the classifier:
#   input -> augmentation -> MobileNetV2 preprocessing -> backbone -> pooled softmax head
inputs = tf.keras.Input(shape=(*img_size, 3))
features = preprocess_input(data_augmentation(inputs))  # preprocess for MobileNetV2

# training=False: run the backbone in inference mode (no BatchNorm updates).
features = base_model(features, training=False)
features = layers.GlobalAveragePooling2D()(features)
features = layers.Dropout(0.3)(features)  # helps prevent overfitting
outputs = layers.Dense(num_classes, activation="softmax")(features)

model = models.Model(inputs=inputs, outputs=outputs)

# 5. Compile

In [None]:
# Adam with a small learning rate (1e-4) for the first training stage.
# The sparse categorical cross-entropy loss pairs with the softmax output
# layer; it presumably expects integer class-index labels (the
# image_dataset_from_directory default — confirm).
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# Print the layer-by-layer architecture and parameter counts.
model.summary()

# 6. Train

In [None]:
# Early stopping: end training once val_loss has not improved for
# `patience` (5) epochs, and restore the weights from the best epoch.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True
)

In [None]:
# Stage 1: train with the MobileNetV2 backbone frozen. Runs for at most
# 30 epochs; the early-stopping callback may end it sooner, so
# `history.history` can hold fewer than 30 entries per metric.
history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=30,
                    callbacks=[early_stopping])

# 7. Fine-tune

In [None]:
# Stage 2: fine-tune the top of the backbone.
# Unfreeze the whole backbone, then re-freeze everything except its last 40
# layers. (Slicing with `[:]` would re-freeze *every* layer and leave nothing
# in the backbone to fine-tune.)
base_model.trainable = True
for layer in base_model.layers[:-40]:  # freeze all but last 40 layers
    layer.trainable = False

# Re-compile after changing the `trainable` flags, with a learning rate 10x
# smaller than in stage 1 (1e-5 vs 1e-4).
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# Continue training for at most 30 more epochs; early stopping may end it sooner.
history_fine = model.fit(train_ds,
                         validation_data=val_ds,
                         epochs=30,
                         callbacks=[early_stopping])

# 8. Plot training history

In [None]:
# Per-epoch metrics recorded during the first (frozen-backbone) training run.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

# Early stopping can end training before the 30 requested epochs, so size the
# x-axis from the recorded history; a fixed range(30) would make plt.plot
# raise on mismatched x/y lengths.
epochs_range = range(len(acc))

plt.figure(figsize=(8, 4))

# Left panel: accuracy curves.
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.legend()
plt.title('Training and Validation Accuracy')

# Right panel: loss curves.
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.xlabel('Epoch')
plt.legend()
plt.title('Training and Validation Loss')
plt.show()

# 9. Test on a single image

In [None]:

def predict_single_image(img_path, model, class_names, top_k=3):
    """Classify one image file and return its most likely classes.

    Args:
        img_path: Path of the image file to load.
        model: Object whose ``predict`` accepts a batch of one image and
            returns one row of per-class scores. Raw pixel values are passed:
            preprocess_input is NOT applied here because the notebook's model
            already does it internally.
        class_names: Sequence mapping class index -> label.
        top_k: Maximum number of predictions to return.

    Returns:
        A list of ``(class_name, probability)`` tuples sorted by descending
        probability. It has at most ``top_k`` entries and is empty when
        ``top_k <= 0``.
    """
    # Guard: slicing with `[-0:]` selects the whole array, so top_k=0 would
    # otherwise return every class instead of none.
    if top_k <= 0:
        return []

    # Resize to the notebook-level img_size so the input always matches the
    # size the model was built with.
    img = load_img(img_path, target_size=img_size)
    img_array = img_to_array(img)                  # (height, width, channels)
    batch = np.expand_dims(img_array, axis=0)      # add the batch dimension

    # verbose=0: no progress bar for a single-image prediction.
    probs = model.predict(batch, verbose=0)[0]

    # Indices of the highest scores, best first.
    top_idx = np.argsort(probs)[::-1][:top_k]
    return [(class_names[i], float(probs[i])) for i in top_idx]


In [None]:
# Image used for the manual sanity check.
img_path = "test/glass-bottle.png"

# Display the image at the 224x224 size it is loaded at here.
img_vis = image.load_img(img_path, target_size=(224, 224))
plt.imshow(img_vis)
plt.axis('off')
plt.title("Test image")
plt.show()

# Print the three most likely classes with their probabilities.
print(
    "Top predictions:",
    predict_single_image(img_path, model, class_names, top_k=3),
)